diff --git a/DESCRIPTION b/DESCRIPTION
index c587213..6eef2bd 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,7 +1,7 @@
 Package: LDATree
 Title: Classification Trees with Linear Discriminant Analysis at Terminal Nodes
-Version: 0.1.2
+Version: 0.1.2.9001
 Authors@R: 
     person("Siyu", "Wang", , "swang739@wisc.edu", role = c("cre", "aut", "cph"),
            comment = c(ORCID = "0009-0005-2098-7089"))
@@ -13,11 +13,12 @@ License: MIT + file LICENSE
 URL: https://github.com/Moran79/LDATree, http://iamwangsiyu.com/LDATree/
 BugReports: https://github.com/Moran79/LDATree/issues
 Imports: 
+    folda,
     ggplot2,
-    lifecycle,
+    grDevices,
     magrittr,
-    scales,
     stats,
+    utils,
     visNetwork
 Encoding: UTF-8
 LazyData: true
diff --git a/NAMESPACE b/NAMESPACE
index 3505cf0..75ad3e7 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -1,24 +1,9 @@
 # Generated by roxygen2: do not edit by hand
 
-S3method(plot,SingleTreee)
 S3method(plot,Treee)
-S3method(predict,SingleTreee)
 S3method(predict,Treee)
-S3method(predict,ldaGSVD)
-S3method(print,Treee)
 S3method(print,TreeeNode)
-S3method(print,ldaGSVD)
 S3method(summary,TreeeNode)
 export(Treee)
-export(ldaGSVD)
-importFrom(lifecycle,deprecated)
 importFrom(magrittr,"%>%")
-importFrom(stats,.getXlevels)
-importFrom(stats,delete.response)
-importFrom(stats,model.frame)
-importFrom(stats,model.matrix)
-importFrom(stats,model.matrix.lm)
 importFrom(stats,predict)
-importFrom(stats,quantile)
-importFrom(stats,sd)
-importFrom(stats,terms)
diff --git a/R/LDATree-package.R b/R/LDATree-package.R
index 9576403..2cfbe57 100644
--- a/R/LDATree-package.R
+++ b/R/LDATree-package.R
@@ -2,16 +2,7 @@
 "_PACKAGE"
 
 ## usethis namespace: start
-#' @importFrom lifecycle deprecated
 #' @importFrom magrittr %>%
-#' @importFrom stats .getXlevels
-#' @importFrom stats delete.response
-#' @importFrom stats model.frame
-#' @importFrom stats model.matrix
-#' @importFrom stats model.matrix.lm
 #' @importFrom stats predict
-#' @importFrom stats quantile
-#' @importFrom stats sd
-#' @importFrom stats terms
 ## usethis namespace: end
 NULL
diff --git a/R/Treee.R b/R/Treee.R
index 4732ecb..5d79e48 100644
--- a/R/Treee.R
+++ b/R/Treee.R
@@ -1,105 +1,119 @@
-#' Classification trees with Linear Discriminant Analysis terminal nodes
+#' Classification Trees with Uncorrelated Linear Discriminant Analysis Terminal
+#' Nodes
 #'
-#' @description `r lifecycle::badge('experimental')` Fit an LDATree model.
+#' This function fits a classification tree where each node has an Uncorrelated
+#' Linear Discriminant Analysis (ULDA) model. It can also handle missing values
+#' and perform downsampling. The resulting tree can be pruned either through
+#' pre-pruning or post-pruning methods.
 #'
-#' @details
+#' @param datX A data frame of predictor variables.
+#' @param response A vector of response values corresponding to `datX`.
+#' @param ldaType A character string specifying the type of LDA to use. Options
+#'   are `"forward"` for forward ULDA or `"all"` for full ULDA. Default is
+#'   `"forward"`.
+#' @param nodeModel A character string specifying the type of model used in each
+#'   node. Options are `"ULDA"` for Uncorrelated LDA, or `"mode"` for predicting
+#'   based on the most frequent class. Default is `"ULDA"`.
+#' @param pruneMethod A character string specifying the pruning method. `"pre"`
+#'   performs pre-pruning based on p-value thresholds, and `"post"` performs
+#'   cross-validation-based post-pruning. Default is `"pre"`.
+#' @param numberOfPruning An integer specifying the number of folds for
+#'   cross-validation during post-pruning. Default is `10`.
+#' @param maxTreeLevel An integer controlling the maximum depth of the tree.
+#'   Increasing this value allows for deeper trees with more nodes. Default is
+#'   `20`.
+#' @param minNodeSize An integer controlling the minimum number of samples
+#'   required in a node. Setting a higher value may lead to earlier stopping and
+#'   smaller trees. If not specified, it defaults to one plus the number of
+#'   response classes.
+#' @param pThreshold A numeric value used as a threshold for pre-pruning based
+#'   on p-values. Lower values result in more conservative trees. If not
+#'   specified, defaults to `0.01` for pre-pruning and `0.51` for post-pruning.
+#' @param prior A numeric vector of prior probabilities for each class. If
+#'   `NULL`, the prior is automatically calculated from the data.
+#' @param misClassCost A square matrix \eqn{C}, where each element \eqn{C_{ij}}
+#'   represents the cost of classifying an observation into class \eqn{i} given
+#'   that it truly belongs to class \eqn{j}. If `NULL`, a default matrix with
+#'   equal misclassification costs for all class pairs is used. Default is
+#'   `NULL`.
+#' @param missingMethod A character string specifying how missing values should
+#'   be handled. Options include `'mean'`, `'median'`, `'meanFlag'`,
+#'   `'medianFlag'` for numerical variables, and `'mode'`, `'modeFlag'`,
+#'   `'newLevel'` for factor variables. `'Flag'` options indicate whether a
+#'   missing flag is added, while `'newLevel'` replaces missing values with a
+#'   new factor level.
+#' @param kSample An integer specifying the number of samples to use for
+#'   downsampling during tree construction. Set to `-1` to disable downsampling.
+#' @param verbose A logical value. If `TRUE`, progress messages and detailed
+#'   output are printed during tree construction and pruning. Default is
+#'   `FALSE`.
 #'
+#' @returns An object of class `Treee` containing the fitted tree, which is a
+#'   list of nodes, each an object of class `TreeeNode`. Each `TreeeNode`
+#'   contains:
+#' * `currentIndex`: The node index in the tree.
+#' * `currentLevel`: The depth of the current node in the tree.
+#' * `idxRow`, `idxCol`: Row and column indices indicating which part of the original data was used for this node.
+#' * `currentLoss`: The training error for this node.
+#' * `accuracy`: The training accuracy for this node.
+#' * `stopInfo`: Information on why the node stopped growing.
+#' * `proportions`: The observed frequency of each class in this node.
+#' * `prior`: The (adjusted) class prior probabilities used for ULDA or mode prediction.
+#' * `misClassCost`: The misclassification cost matrix used in this node.
+#' * `parent`: The index of the parent node.
+#' * `children`: A vector of indices of this node’s direct children.
+#' * `splitFun`: The splitting function used for this node.
+#' * `nodeModel`: Indicates the model fitted at the node (`'ULDA'` or `'mode'`).
+#' * `nodePredict`: The fitted model at the node, either a ULDA object or the plurality class.
+#' * `alpha`: The p-value from a two-sample t-test used to evaluate the strength of the split.
+#' * `childrenTerminal`: A vector of indices representing the terminal nodes that are descendants of this node.
+#' * `childrenTerminalLoss`: The total training error accumulated from all nodes listed in `childrenTerminal`.
 #'
-#' Unlike other classification trees, LDATree integrates LDA throughout the
-#' entire tree-growing process. Here is a breakdown of its distinctive features:
-#' * The tree searches for the best binary split based on sample quantiles of the first linear discriminant score.
-#'
-#' * An LDA/GSVD model is fitted for each terminal node (For more details, refer to [ldaGSVD()]).
-#'
-#' * Missing values can be imputed using the mean, median, or mode, with optional missing flags available.
-#'
-#' * By default, the tree employs a direct-stopping rule. However, cross-validation using the alpha-pruning from CART is also provided.
-#'
-#' @param formula an object of class [formula], which has the form `class ~ x1 +
-#'   x2 + ...`
-#' @param data a data frame that contains both predictors and the response.
-#'   Missing values are allowed in predictors but not in the response.
-#' @param missingMethod Missing value solutions for numerical variables and
-#'   factor variables. `'mean'`, `'median'`, `'meanFlag'`, `'medianFlag'` are
-#'   available for numerical variables. `'mode'`, `'modeFlag'`, `'newLevel'` are
-#'   available for factor variables. The word `'Flag'` in the methods indicates
-#'   whether a missing flag is added or not. The `'newLevel'` method means that
-#'   all missing values are replaced with a new level rather than imputing them
-#'   to another existing value.
-#' @param maxTreeLevel controls the largest tree size possible for either a
-#'   direct-stopping tree or a CV-pruned tree. Adding one extra level (depth)
-#'   introduces an additional layer of nodes at the bottom of the current tree.
-#'   e.g., when the maximum level is 1 (or 2), the maximum tree size is 3 (or
-#'   7).
-#' @param minNodeSize controls the minimum node size. Think carefully before
-#'   changing this value. Setting a large number might result in early stopping
-#'   and reduced accuracy. By default, it's set to one plus the number of
-#'   classes in the response variable.
-#' @param verbose a logical. If TRUE, the function provides additional
-#'   diagnostic messages or detailed output about its progress or internal
-#'   workings. Default is FALSE, where the function runs silently without
-#'   additional output.
-#'
-#' @returns An object of class `Treee` containing the following components:
-#' * `formula`: the formula passed to the [Treee()]
-#' * `treee`: a list of all the tree nodes, and each node is an object of class `TreeeNode`.
-#' * `missingMethod`: the missingMethod passed to the [Treee()]
-#'
-#' An object of class `TreeeNode` containing the following components:
-#' * `currentIndex`: the node index of the current node
-#' * `currentLevel`: the level of the current node in the tree
-#' * `idxRow`, `idxCol`: the row and column indices showing which portion of data is used in the current node
-#' * `currentLoss`: ?
-#' * `accuracy`: the training accuracy of the current node
-#' * `stopFlag`: ?
-#' * `proportions`: shows the observed frequency for each class
-#' * `parent`: the node index of its parent
-#' * `children`: the node indices of its direct children (not including its children's children)
-#' * `misReference`: a data frame, serves as the reference for missing value imputation
-#' * `splitFun`: ?
-#' * `nodeModel`: one of `'mode'` or `'LDA'`. It shows the type of predictive model fitted in the current node
-#' * `nodePredict`: the fitted predictive model in the current node. It is an object of class `ldaGSVD` if LDA is fitted. If `nodeModel = 'mode'`, then it is a vector of length one, showing the plurality class.
 #' @export
 #'
+#' @references Wang, S. (2024). A New Forward Discriminant Analysis Framework
+#'   Based On Pillai's Trace and ULDA. \emph{arXiv preprint arXiv:2409.03136}.
+#'   Available at \url{https://arxiv.org/abs/2409.03136}.
+#'
 #' @examples
-#' fit <- Treee(Species~., data = iris)
+#' fit <- Treee(datX = iris[, -5], response = iris[, 5], verbose = FALSE)
 #' # Use cross-validation to prune the tree
-#' fitCV <- Treee(Species~., data = iris)
-#' # prediction
-#' predict(fit,iris)
-#' # plot the overall tree
-#' plot(fit)
-#' # plot a certain node
-#' plot(fit, iris, node = 1)
+#' fitCV <- Treee(datX = iris[, -5], response = iris[, 5], pruneMethod = "post", verbose = FALSE)
+#' head(predict(fit, iris)) # prediction
+#' plot(fit) # plot the overall tree
+#' plot(fit, datX = iris[, -5], response = iris[, 5], node = 1) # plot a certain node
 Treee <- function(datX,
                   response,
-                  ldaType = c("step", "all"),
-                  nodeModel = c("LDA", "mode"),
-                  missingMethod = c("medianFlag", "newLevel"),
-                  prior = NULL,
-                  misClassCost = NULL,
-                  pruneMethod = c("post", "pre", "pre-post"),
-                  numberOfPruning = 10,
+                  ldaType = c("forward", "all"),
+                  nodeModel = c("ULDA", "mode"),
+                  pruneMethod = c("pre", "post"),
+                  numberOfPruning = 10L,
                   maxTreeLevel = 20L,
                   minNodeSize = NULL,
-                  pThreshold = 0.1,
-                  verbose = TRUE,
-                  kSample = 1e7){
+                  pThreshold = NULL,
+                  prior = NULL,
+                  misClassCost = NULL,
+                  missingMethod = c("medianFlag", "newLevel"),
+                  kSample = -1,
+                  verbose = TRUE){ # Change verbose to FALSE before CRAN submission
 
   # Standardize the Arguments -----------------------------------------------
-  response <- as.factor(response) # make it a factor
-  ldaType <- match.arg(ldaType, c("step", "all"))
-  nodeModel <- match.arg(nodeModel, c("LDA", "mode"))
-  missingMethod <- c(match.arg(missingMethod[1], c("mean", "median", "meanFlag", "medianFlag")),
-                     match.arg(missingMethod[2], c("mode", "modeFlag", "newLevel")))
-  prior <- checkPriorAndMisClassCost(prior = prior, misClassCost = misClassCost, response = response, internal = TRUE)
-  pruneMethod <- match.arg(pruneMethod, c("post", "pre", "pre-post"))
-  if(is.null(minNodeSize)) minNodeSize <- nlevels(response) + 1 # minNodeSize: If not specified, set to J+1
-
-  #> Does not support ordered factors
-  for(i in seq_along(datX)){
-    if("ordered" %in% class(datX[,i])) class(datX[,i]) <- "factor"
+  datX <- data.frame(datX) # change to data.frame, remove the potential tibble attribute
+  for(i in seq_along(datX)){ # remove ordered factors
+    if(inherits(datX[,i], c("ordered"))) class(datX[,i]) <- "factor"
   }
+
+  # Remove NAs in the response
+  idxNonNA <- which(!is.na(response)); response <- droplevels(factor(response[idxNonNA], ordered = FALSE))
+  datX <- datX[idxNonNA, , drop = FALSE]
+
+  ldaType <- match.arg(ldaType, c("forward", "all"))
+  nodeModel <- match.arg(nodeModel, c("ULDA", "mode"))
+  pruneMethod <- match.arg(pruneMethod, c("pre", "post"))
+  if(is.null(minNodeSize)) minNodeSize <- nlevels(response) + 1 # minNodeSize: If not specified, set to J+1
+  if(is.null(pThreshold)) pThreshold <- ifelse(pruneMethod == "pre", 0.01, 0.51)
+  priorAndMisClassCost <- updatePriorAndMisClassCost(prior = prior, misClassCost = misClassCost, response = response, insideNode = FALSE)
+  prior <- priorAndMisClassCost$prior; misClassCost <- priorAndMisClassCost$misClassCost
 
   # Build Different Trees ---------------------------------------------------
 
@@ -107,51 +121,39 @@ Treee <- function(datX,
                              response = response,
                              ldaType = ldaType,
                              nodeModel = nodeModel,
-                             missingMethod = missingMethod,
-                             prior = prior,
                              maxTreeLevel = maxTreeLevel,
                              minNodeSize = minNodeSize,
-                             pThreshold = ifelse(pruneMethod == "pre", pThreshold, 0.51),
-                             verbose = verbose,
-                             kSample = kSample)
-
-  if(pruneMethod == "pre-post"){
-    #> Step Ahead and prune back, 0.51 as a looser bound
-    #> Discard due to unnecessary
-    treeeNow <- updateAlphaInTree(treeeNow)
-    treeeNow <- pruneTreee(treeeNow, pThreshold)
-    treeeNow <- dropNodes(treeeNow)
-  }
+                             pThreshold = pThreshold,
+                             prior = prior,
+                             misClassCost = misClassCost,
+                             missingMethod = missingMethod,
+                             kSample = kSample,
+                             verbose = verbose)
 
   if(verbose) cat(paste('\nThe pre-pruned LDA tree is completed. It has', length(treeeNow), 'nodes.\n'))
 
-  finalTreee <- structure(list(treee = treeeNow,
-                               missingMethod = missingMethod), class = "Treee")
-
-  # Pruning -----------------------------------------------------------------
-
   if(pruneMethod == "post" & length(treeeNow) > 1){
     pruningOutput <- prune(oldTreee = treeeNow,
-                           numberOfPruning = numberOfPruning,
                            datX = datX,
                            response = response,
                            ldaType = ldaType,
                            nodeModel = nodeModel,
-                           missingMethod = missingMethod,
-                           prior = prior,
+                           numberOfPruning = numberOfPruning,
                            maxTreeLevel = maxTreeLevel,
                            minNodeSize = minNodeSize,
-                           pThreshold = 0.51,
-                           verbose = verbose,
-                           kSample = kSample)
-
-    # Add something to the finalTreee
-    finalTreee$treee <- pruningOutput$treeeNew
-    finalTreee$CV_Table <- pruningOutput$CV_Table
+                           pThreshold = pThreshold,
+                           prior = prior,
+                           misClassCost = misClassCost,
+                           missingMethod = missingMethod,
+                           kSample = kSample,
+                           verbose = verbose)
 
-    if(verbose) cat(paste('\nThe post-pruned tree is completed. It has', length(finalTreee$treee), 'nodes.\n'))
+    treeeNow <- pruningOutput$treeeNew
+    attr(treeeNow, "CV_Table") <- pruningOutput$CV_Table
+    if(verbose) cat(paste('\nThe post-pruned tree is completed. It has', length(treeeNow), 'nodes.\n'))
  }
 
-  return(finalTreee)
+  return(treeeNow)
 }
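
The reworked `Treee()` signature above takes priors and misclassification costs directly. A hedged usage sketch based only on the arguments documented in this diff (the `cost` matrix below is illustrative, not from the package):

# Cap the depth and switch the numerical imputation to mean-plus-flag:
fitShallow <- Treee(datX = iris[, -5], response = iris[, 5],
                    maxTreeLevel = 3,
                    missingMethod = c("meanFlag", "newLevel"),
                    verbose = FALSE)

# Asymmetric costs: element C[i, j] is the cost of predicting class i
# when the truth is class j (per the @param misClassCost doc above).
cost <- matrix(c(0, 1, 1,
                 5, 0, 1,
                 5, 1, 0), nrow = 3,
               dimnames = list(levels(iris$Species), levels(iris$Species)))
fitCost <- Treee(datX = iris[, -5], response = iris[, 5],
                 misClassCost = cost, verbose = FALSE)
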
diff --git a/R/archivedFuns.R b/R/archivedFuns.R
index bc93301..63cdb19 100644
--- a/R/archivedFuns.R
+++ b/R/archivedFuns.R
@@ -1,42 +1,3 @@
-
-# Check for input prior ---------------------------------------------------
-
-checkPrior <- function(prior, response){
-  ## Modified from randomForest.default Line 114
-  if (is.null(prior)) {
-    prior <- table(response) / length(response) # Default: Estimated Prior
-  } else {
-    if (length(prior) != nlevels(response))
-      stop("length of prior not equal to number of classes")
-    if (!is.null(names(prior))){
-      prior <- prior[findTargetIndex(names(prior), levels(response))]
-    }
-    if (any(prior < 0)) stop("prior must be non-negative")
-  }
-  return(prior / sum(prior))
-}
-
-
-# Check for input misclassification cost ---------------------------------------------------
-
-checkMisClassCost <- function(misClassCost, response){
-  if (is.null(misClassCost)) {
-    misClassCost <- (1 - diag(nlevels(response))) # Default: 1-identity
-    colnames(misClassCost) <- rownames(misClassCost) <- levels(response)
-  } else {
-    if (dim(misClassCost)[1] != dim(misClassCost)[2] | dim(misClassCost)[1] != nlevels(response))
-      stop("misclassification costs matrix has wrong dimension")
-    if(!all.equal(colnames(misClassCost), rownames(misClassCost)))
-      stop("misClassCost: colnames should be the same as rownames")
-    if (!is.null(colnames(misClassCost))){
-      misClassCost <- misClassCost[findTargetIndex(colnames(misClassCost), levels(response)),
-                                   findTargetIndex(colnames(misClassCost), levels(response))]
-    }
-  }
-  return(misClassCost)
-}
-
-
 # constant In Group fix ---------------------------------------------------------
 
 fixConstantGroup <- function(data, response, tol = 1e-8){
@@ -230,6 +191,22 @@ dropNodes <- function(treeeList){
   return(treeeList)
 }
 
+# Get tree depth ----------------------------------------------------------
+
+getDepth <- function(treee){
+  if(inherits(treee, "Treee")) treee <- treee$treee
+  depthAll <- numeric(length(treee))
+  updateList <- seq_along(treee)[-1]
+  if(length(updateList) != 0){
+    for(i in updateList){
+      currentNode <- treee[[i]]
+      depthAll[i] <- depthAll[currentNode$parent] + 1
+    }
+  }
+  return(depthAll)
+}
+
+
 # New Pruning (previous prune.R) ---------------------------------------------------
 
 #> Usage
@@ -266,12 +243,6 @@ makeAlphaMono <- function(treeeList,
-
-
-
-
-
-
 # Splitting ---------------------------------------------------------------
 
 ### Sum of Pillai's trace ###
@@ -622,45 +593,6 @@ getSplitFunRandom <- function(datX, response, modelLDA){
 }
 
-
-
-
-# fixNewLevel -------------------------------------------------------------
-
-
-fixNewLevel <- function(datTest, datTrain){
-  #> change the shape of test data to the training data
-  #> and make sure that the dimension of the data is the same as missingRefernce
-
-  nameVarIdx <- match(colnames(datTrain), colnames(datTest))
-  if(anyNA(nameVarIdx)){
-    #> New columns fix (or Flags): If there are less columns than it should be,
-    #> add columns with NA
-    datTest[,colnames(datTrain)[which(is.na(nameVarIdx))]] <- NA
-    nameVarIdx <- match(colnames(datTrain), colnames(datTest))
-  }
-  datTest <- datTest[,nameVarIdx, drop = FALSE]
-  #> The columns are the same, now fix new levels
-
-  #> change all characters to factors
-  idxC <- which(sapply(datTrain, class) == "character")
-  if(length(idxC) != 0){
-    for(i in idxC){
-      datTrain[,i] <- as.factor(datTrain[,i])
-    }
-  }
-
-  #> New levels fix
-  levelReference <- sapply(datTrain, levels, simplify = FALSE)
-  for(i in which(!sapply(levelReference,is.null))){
-    # sapply would lose factor property but left character, why?
-    datTest[,i] <- factor(datTest[,i], levels = levelReference[[i]])
-  }
-
-  return(datTest)
-}
-
-
 # Get x and response from formula and data --------------------------------
 
 extractXnResponse <- function(formula, data){
@@ -863,5 +795,71 @@ predict.ForestTreee <- function(object, newdata, type = "response", ...){
 }
 
+# Variable Selection ------------------------------------------------------
+
+getChiSqStat <- function(datX, y){
+  sapply(datX, function(x) getChiSqStatHelper(x, y))
+}
+
+getChiSqStatHelper <- function(x,y){
+  if(getNumFlag(x)){ # numerical variable: first change to factor
+    m = mean(x,na.rm = T); s = sd(x,na.rm = T)
+    if(sum(!is.na(x)) >= 30 * nlevels(y)){
+      splitNow = c(m - s *sqrt(3)/2, m, m + s *sqrt(3)/2)
+    }else splitNow = c(m - s *sqrt(3)/3, m + s *sqrt(3)/3)
+
+    if(length(unique(splitNow)) == 1) return(0) # No possible split
+    x = cut(x, breaks = c(-Inf, splitNow, Inf), right = TRUE)
+  }
+  if(anyNA(x)){
+    levels(x) = c(levels(x), 'newLevel')
+    x[is.na(x)] <- 'newLevel'
+  }
+  if(length(unique(x)) == 1) return(0) # No possible split
+
+  fit <- suppressWarnings(chisq.test(x, y))
+
+  #> Change to 1-df wilson_hilferty chi-squared stat unless
+  #> the original df = 1 and p-value is larger than 10^(-16)
+  ans = unname(ifelse(fit$parameter > 1L, ifelse(fit$p.value > 10^(-16),
+                                                 qchisq(1-fit$p.value, df = 1),
+                                                 wilson_hilferty(fit$statistic,fit$parameter)), fit$statistic))
+  return(ans)
+}
+
+wilson_hilferty = function(chi, df){ # change df = K to df = 1
+  ans = max(0, (7/9 + sqrt(df) * ( (chi / df) ^ (1/3) - 1 + 2 / (9 * df) ))^3)
+  return(ans)
+}
+
+
+# svd helper ---------------------------------------------------------------
+
+
+saferSVD <- function(x, ...){
+  #> Target for error code 1 from Lapack routine 'dgesdd' non-convergence error
+  #> Current solution: Round the design matrix to make approximations,
+  #> hopefully this will solve the problem
+  #>
+  #> The code is a little lengthy, since the variable assignment in tryCatch is tricky
+  parList <- list(svdObject = NULL,
+                  svdSuccess = FALSE,
+                  errorDigits = 16,
+                  x = x)
+  while (!parList$svdSuccess) {
+    parList <- tryCatch({
+      parList$svdObject <- svd(parList$x, ...)
+      parList$svdSuccess <- TRUE
+      parList
+    }, error = function(e) {
+      if (grepl("error code 1 from Lapack routine 'dgesdd'", e$message)) {
+        parList$x <- round(x, digits = parList$errorDigits)
+        parList$errorDigits <- parList$errorDigits - 1
+        return(parList)
+      } else stop(e)
+    })
+  }
+  return(parList$svdObject)
+}
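
The `wilson_hilferty()` helper added above maps a chi-squared statistic on `df = K` to an approximately p-value-preserving statistic on the 1-df scale (it is the 1-df Wilson-Hilferty inverse applied to the K-df normal score). A quick sanity check, not part of the package:

chi5 <- qchisq(0.99, df = 5)  # statistic with upper-tail probability 0.01 on 5 df
chi1 <- max(0, (7/9 + sqrt(5) * ((chi5 / 5)^(1/3) - 1 + 2 / (9 * 5)))^3)
pchisq(chi1, df = 1, lower.tail = FALSE)  # ~0.0104, close to the original 0.01
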
diff --git a/R/ldaGSVD.R b/R/ldaGSVD.R
deleted file mode 100644
index 8cdb699..0000000
--- a/R/ldaGSVD.R
+++ /dev/null
@@ -1,365 +0,0 @@
-#' Linear Discriminant Analysis using the Generalized Singular Value
-#' Decomposition
-#'
-#' @description `r lifecycle::badge('experimental')` Fit an LDA/GSVD model.
-#'
-#' @details
-#'
-#' Traditional Fisher's Linear Discriminant Analysis (LDA) ceases to work when
-#' the within-class scatter matrix is singular. The Generalized Singular Value
-#' Decomposition (GSVD) is used to address this issue. GSVD simultaneously
-#' diagonalizes both the within-class and between-class scatter matrices without
-#' the need to invert a singular matrix. This method is believed to be more
-#' accurate than PCA-LDA (as in `MASS::lda`) because it also considers the
-#' information in the between-class scatter matrix.
-#'
-#' @param data a data frame that contains both predictors and the response.
-#'   Missing values are NOT allowed.
-#' @param method default to be all
-#'
-#' @returns An object of class `ldaGSVD` containing the following components:
-#' * `scaling`: a matrix which transforms the training data to LD scores, normalized so that the within-group scatter matrix is proportional to the identity matrix.
-#' * `formula`: the formula passed to the [ldaGSVD()]
-#' * `terms`: a object of class `terms` derived using the input `formula` and the training data
-#' * `prior`: a `table` of the estimated prior probabilities.
-#' * `groupMeans`: a matrix that records the group means of the training data on the transformed LD scores.
-#' * `xlevels`: a list records the levels of the factor predictors, derived using the input `formula` and the training data
-#'
-#' @export
-#'
-#' @references Ye, J., Janardan, R., Park, C. H., & Park, H. (2004). \emph{An
-#'   optimization criterion for generalized discriminant analysis on
-#'   undersampled problems}. IEEE Transactions on Pattern Analysis and Machine
-#'   Intelligence
-#'
-#'   Howland, P., Jeon, M., & Park, H. (2003). \emph{Structure preserving dimension
-#'   reduction for clustered text data based on the generalized singular value
-#'   decomposition}. SIAM Journal on Matrix Analysis and Applications
-#'
-#' @examples
-#' fit <- ldaGSVD(Species~., data = iris)
-#' # prediction
-#' predict(fit,iris)
-ldaGSVD <- function(datX,
-                    response,
-                    method = c("all", "step"),
-                    fixNA = TRUE,
-                    missingMethod = c("medianFlag", "newLevel"),
-                    prior = NULL,
-                    misClassCost = NULL,
-                    insideTree = FALSE){
-
-
-  # Pre-processing: Arguments and response ----------------------------------
-
-  method <- match.arg(method, c("all", "step"))
-  missingMethod <- c(match.arg(missingMethod[1], c("mean", "median", "meanFlag", "medianFlag")),
-                     match.arg(missingMethod[2], c("mode", "modeFlag", "newLevel")))
-  stopifnot(!anyNA(response)) # No NAs in the response variable
-  response <- droplevels(as.factor(response)) # some levels are branched out
-
-  #> Get the prior
-  if(insideTree){
-    prior <- getFinalPrior(prior = prior, response = response)
-  }else prior <- checkPriorAndMisClassCost(prior = prior, misClassCost = misClassCost, response = response, internal = FALSE)
-
-
-  # Pre-processing: Variables -----------------------------------------------
-
-  #> Variable Selection Step: for stepwise LDA only
-  if(method == "step"){
-    chiStat <- getChiSqStat(datX = datX, y = response)
-    idxKeep <- which(chiStat >= qchisq(1 - 0.05/length(chiStat), 1)) # Bonferroni
-    if(length(idxKeep) == 0) idxKeep <- seq_len(length(chiStat))
-    datX <- datX[, idxKeep, drop = FALSE]
-  }
-
-  if(fixNA){
-    imputedSummary <- missingFix(data = datX, missingMethod = missingMethod)
-    if(anyNA(datX)) datX <- imputedSummary$data
-  }
-
-  modelFrame <- model.frame(formula = ~.-1, datX, na.action = "na.fail")
-  Terms <- terms(modelFrame)
-  m <- scale(model.matrix(Terms, modelFrame)) # constant cols would be changed to NaN in this step
-  cnames <- colnames(m)
-  currentVarList <- as.vector(which(apply(m, 2, function(x) !any(is.nan(x))))) # remove constant columns and intercept
-
-  if(length(currentVarList) == 0) stop("All variables are constant.")
-
-  if(method == "step"){
-    #> Output: currentVarList, which contains indices of the selected variables
-    #> RESPECTIVELY in the design matrix, some columns of m might be removed
-
-    stepRes <- stepVarSelByF(m = m, response = response, currentCandidates = currentVarList)
-
-    #> When no variable is selected, use the full model
-    #> it might be more time-consuming, but it is better for future LDA split
-
-    if(length(stepRes$currentVarList) != 0){
-      currentVarList <- stepRes$currentVarList
-
-      #> modify the design matrix to make it more compact
-      #> so that only the selected variables are included in the design matrix,
-      #> and eventually make the prediction faster
-      selectedVarRawIdx <- unique(sort(attributes(m)$assign[currentVarList])) # MUST be from the modelFrame where the factors are not dummied
-      modelFrame <- model.frame(formula = ~.-1, datX[, selectedVarRawIdx, drop = FALSE], na.action = "na.fail")
-      Terms <- terms(modelFrame)
-      m <- scale(model.matrix(Terms, modelFrame))
-
-      #> select CERTAIN levels of the factor variables, not ALL
-      currentVarList <- which(colnames(m) %in% stepRes$stepInfo$var)
-    }
-  }
-
-  varSD <- attr(m,"scaled:scale")[currentVarList]
-  varCenter <- attr(m,"scaled:center")[currentVarList]
-  m <- m[,currentVarList, drop = FALSE]
-
-  # Step 1: SVD on the combined matrix H
-  groupMeans <- tapply(c(m), list(rep(response, dim(m)[2]), col(m)), function(x) mean(x, na.rm = TRUE))
-  Hb <- sqrt(tabulate(response)) * groupMeans # grandMean = 0 if scaled
-  Hw <- m - groupMeans[response, , drop = FALSE]
-  # if(diff(dim(m)) < 0){ # More rows than columns
-  #   qrRes <- qrEigen(Hw)
-  #   fitSVD <- svdEigen(rbind(Hb, qrRes$R))
-  # }else fitSVD <- svdEigen(rbind(Hb, Hw))
-  fitSVD <- saferSVD(rbind(Hb, m - groupMeans[response, , drop = FALSE]))
-  rankT <- sum(fitSVD$d >= max(dim(fitSVD$u),dim(fitSVD$v)) * .Machine$double.eps * fitSVD$d[1])
-
-  # Step 2: SVD on the P matrix
-  #> The code below can be changed to saferSVD if necessary
-  fitSVDp <- saferSVD(fitSVD$u[seq_len(nlevels(response)), seq_len(rankT), drop = FALSE], nu = 0L)
-  # fitSVDp <- svdEigen(fitSVD$u[seq_len(nlevels(response)), seq_len(rankT), drop = FALSE])
-  rankAll <- min(nlevels(response)-1, rankT) # This is not optimal, but rank(Hb) takes time
-
-  # Fix the variance part
-  unitSD <- diag(sqrt((length(response) - nlevels(response)) / abs(1 - fitSVDp$d^2 + 1e-15)), nrow = rankAll) # Scale to unit var
-  scalingFinal <- (fitSVD$v[,seq_len(rankT), drop = FALSE] %*% diag(1 / fitSVD$d[seq_len(rankT)], nrow = rankT) %*% fitSVDp$v[,seq_len(rankAll), drop = FALSE]) %*% unitSD
-  rownames(scalingFinal) <- cnames[currentVarList]
-
-  groupMeans <- groupMeans %*% scalingFinal
-  rownames(groupMeans) <- levels(response)
-  colnames(groupMeans) <- colnames(scalingFinal) <- paste("LD", seq_len(ncol(groupMeans)), sep = "")
-
-  # Get the test statistics and related p-value
-  statPillai <- sum(fitSVDp$d[seq_len(rankAll)]^2)
-  #> s & p are changed here, since sometimes design matrix is not of full rank p
-  p <- rankT; J <- nlevels(response); N <- nrow(m)
-  s <- rankAll; numF <- N-J-p+s; denF <- abs(p-J+1)+s
-
-  #> When numF is non-positive, Pillai = s & training accuracy = 100%
-  #> since there always exist a dimension where we can separate every class perfectly
-  # pValue <- pf(numF / denF * statPillai / (s - statPillai), df1 = s*denF, df2 = s*numF, lower.tail = F) # the same answer
-  pValue <- ifelse(numF > 0, pbeta(1 - statPillai / s, shape1 = numF * s / 2, shape2 = denF * s / 2), 0)
-  if(method == "step") pValue <- pValue * length(cnames) # Bonferroni correction
-
-  res <- list(scaling = scalingFinal, terms = Terms, prior = prior,
-              groupMeans = groupMeans, xlevels = .getXlevels(Terms, modelFrame),
-              varIdx = currentVarList, varSD = varSD, varCenter = varCenter, statPillai = statPillai,
-              pValue = pValue)
-  if(fixNA) res$misReference <- imputedSummary$ref
-  if(method == "step"){
-    res$stepInfo = stepRes$stepInfo
-    res$stopFlag <- stepRes$stopFlag
-  }
-  class(res) <- "ldaGSVD"
-
-  # for LDA splitting
-  currentP <- unname(table(predict(res, datX)) / dim(datX)[1])
-  res$predGini <- 1 - sum(currentP^2)
-  return(res)
-}
-
-
-
-#' Predictions from a fitted ldaGSVD object
-#'
-#' Prediction of test data using a fitted ldaGSVD object
-#'
-#' Unlike the original paper, which uses the k-nearest neighbor (k-NN) as the
-#' classifier, we use a faster and more straightforward likelihood-based method.
-#' One limitation of the traditional likelihood-based method for LDA is that it
-#' ceases to work when there are Linear Discriminant (LD) directions with zero
-#' variance in the within-class scatter matrix. However, when using LDA/GSVD,
-#' all chosen LD directions possess non-zero variance in the between-class
-#' scatter matrix. This implies that LD directions with zero variance in the
-#' within-class scatter matrix will yield the highest Fisher's ratio. Therefore,
-#' to get these directions higher weights, we manually adjust the zero variance
-#' to `1e-15` for computational reasons.
-#'
-#' @param object a fitted model object of class `ldaGSVD`, which is assumed to
-#'   be the result of the [ldaGSVD()] function.
-#' @param newdata data frame containing the values at which predictions are
-#'   required. Missing values are NOT allowed.
-#' @param type character string denoting the type of predicted value returned.
-#'   The default is to return the predicted class (`type` = 'response'). The
-#'   predicted posterior probabilities for each class will be returned if `type`
-#'   = 'prob'.
-#' @param ... further arguments passed to or from other methods.
-#'
-#' @return The function returns different values based on the `type`, if
-#' * `type = 'response'`: vector of predicted responses.
-#' * `type = 'prob'`: a data frame of the posterior probabilities. Each class takes a column.
-#' @export
-#'
-#' @references Ye, J., Janardan, R., Park, C. H., & Park, H. (2004). \emph{An
-#'   optimization criterion for generalized discriminant analysis on
-#'   undersampled problems}. IEEE Transactions on Pattern Analysis and Machine
-#'   Intelligence
-#'
-#'   Howland, P., Jeon, M., & Park, H. (2003). \emph{Structure preserving dimension
-#'   reduction for clustered text data based on the generalized singular value
-#'   decomposition}. SIAM Journal on Matrix Analysis and Applications
-#'
-#' @examples
-#' fit <- ldaGSVD(Species~., data = iris)
-#' predict(fit,iris)
-#' # output prosterior probabilities
-#' predict(fit,iris,type = "prob")
-predict.ldaGSVD <- function(object, newdata, type = c("response", "prob"), ...){
-  type <- match.arg(type, c("response", "prob"))
-  # add one extra check for levels of the predictors
-  LDscores <- getLDscores(modelLDA = object, data = newdata)
-  loglikelihood <- LDscores %*% t(object$groupMeans) + matrix(log(object$prior) - 0.5 * rowSums(object$groupMeans^2), nrow(LDscores), length(object$prior), byrow = TRUE)
-  # Computation Optimization 2: Prevent a very large likelihood due to exponential
-  likelihood <- exp(loglikelihood - apply(loglikelihood, 1, max))
-  posterior <- likelihood / apply(likelihood, 1, sum)
-  if(type == "prob") return(posterior)
-  return(rownames(object$groupMeans)[max.col(posterior, ties.method = "first")])
-}
-
-
-#' @export
-print.ldaGSVD <- function(x, ...){
-  cat("\nObserved proportions of groups:\n")
-  print(x$prior)
-  cat("\n\nGroup means of LD scores:\n")
-  print(x$groupMeans)
-  cat("\n\nScaling (coefficients) of LD scores:\n")
-  print(x$scaling)
-  invisible(x)
-}
-
-
-# helper functions --------------------------------------------------------
-
-getPillai <- function(Sw, St){
-  #> return -1 if the St is not full rank
-  #> sometimes diag has negative value due to instability of R
-  #> We ignore those for now
-  tryCatch({
-    sum(diag(solve(St, St - Sw)))
-  }, error = function(e) {-1})
-}
-
-selectVar <- function(currentVar, newVar, Sw, St, direction = "forward"){
-  #> return the column index
-  #> return 0 if all var makes St = 0
-  if(direction == "forward"){
-    lambdaAll <- sapply(newVar, function(i) getPillai(Sw[c(i,currentVar),c(i,currentVar), drop = FALSE], St[c(i,currentVar),c(i,currentVar), drop = FALSE]))
-  }else{
-    lambdaAll <- sapply(currentVar, function(i) getPillai(Sw[setdiff(currentVar, i),setdiff(currentVar, i), drop = FALSE], St[setdiff(currentVar, i),setdiff(currentVar, i), drop = FALSE]))
-  }
-  maxVarIdx <- which(lambdaAll == max(lambdaAll)) # all variables that achieve maximum
-  currentVarIdx <- maxVarIdx[1]
-
-  return(list(stopflag = (lambdaAll[currentVarIdx] == -1),
-              varIdx = newVar[currentVarIdx],
-              statistics = lambdaAll[currentVarIdx],
-              maxVarIdx = newVar[maxVarIdx]))
-}
-
-
-stepVarSelByF <- function(m, response, currentCandidates){
-  idxOriginal <- currentCandidates
-  m <- m[,currentCandidates, drop = FALSE] # all columns should be useful
-
-  groupMeans <- tapply(c(m), list(rep(response, dim(m)[2]), col(m)), function(x) mean(x, na.rm = TRUE))
-  mW <- m - groupMeans[response, , drop = FALSE]
-
-  # Initialize
-  n = nrow(m); g = nlevels(response); p = 0; currentVarList = c()
-  previousPillai <- previousDiff <- previousDiffDiff <- numeric(ncol(m)+1);
-  previousDiff[1] <- Inf; diffChecker <- 0
-  kRes <- 1; currentCandidates <- seq_len(ncol(m))
-  Sw <- St <- matrix(NA, nrow = ncol(m), ncol = ncol(m))
-  diag(Sw) <- apply(mW^2,2,sum); diag(St) <- apply(m^2,2,sum)
-  stopFlag <- 0
-
-  # Empirical: Calculate the threshold for pillaiToEnter
-  pillaiThreshold <- 1 / (1 + (n-g) / (abs(g-2)+1) / qf(1 - 0.1 / (ncol(m)+1), abs(g-2)+1, n-g)) / currentCandidates^(0.25)
-
-  stepInfo <- data.frame(var = character(2*ncol(m)),
-                         pillaiToEnter = 0,
-                         threhold = pillaiThreshold,
-                         pillaiToRemove = 0,
-                         pillai = 0)
-
-  #> If n <= g, which means there are too few observations,
-  #> we output all columns, and leave that problem to outside function
-  if(anyNA(pillaiThreshold)){
-    currentVarList <- currentCandidates; currentCandidates <- c()
-    stepInfo$var[seq_along(currentVarList)] <- colnames(m)[currentVarList]
-    stopFlag <- 4
-  }
-
-  # Stepwise selection starts!
-  while(length(currentCandidates) != 0){
-
-    nCandidates <- length(currentCandidates)
-    p = p + 1
-    selectVarInfo <- selectVar(currentVar = currentVarList,
-                               newVar = currentCandidates,
-                               Sw = Sw,
-                               St = St)
-    bestVar <- selectVarInfo$varIdx
-    if(selectVarInfo$stopflag){ # If St = 0, stop. [Might never happens, since there are other variables to choose]
-      stopFlag <- 1
-      break
-    }
-
-    # get the difference in Pillai's trace
-    previousDiff[p+1] <- selectVarInfo$statistics - previousPillai[p]
-    previousDiffDiff[p+1] <- previousDiff[p+1] - previousDiff[p]
-    diffChecker <- ifelse(abs(previousDiffDiff[p+1]) < 0.001, diffChecker + 1, 0)
-
-    if(previousDiff[p+1] > 10 * previousDiff[p]){ # Correlated variable(s) is included
-      currentCandidates <- setdiff(currentCandidates, selectVarInfo$maxVarIdx)
-      p <- p - 1; next
-    }
-
-    # Check the stopping rule
-    if(previousDiff[p+1] < pillaiThreshold[p]){ # If no significant variable selected, stop
-      stopFlag <- 2
-      break
-    }
-    if(diffChecker == 10){ # converge
-      stopFlag <- 3
-      break
-    }
-
-    # Add the variable into the model
-    previousPillai[p+1] <- selectVarInfo$statistics
-    currentVarList <- c(currentVarList, bestVar)
-    currentCandidates <- setdiff(currentCandidates, bestVar)
-    stepInfo$var[kRes] <- colnames(m)[bestVar]
-    stepInfo$pillaiToEnter[kRes] <- previousDiff[p+1]
-    stepInfo$pillai[kRes] <- previousPillai[p+1]
-    kRes <- kRes + 1
-
-    # Update the Sw and St on the new added column
-    Sw[currentCandidates, bestVar] <- Sw[bestVar, currentCandidates] <- as.vector(t(mW[, currentCandidates, drop = FALSE]) %*% mW[,bestVar, drop = FALSE])
-    St[currentCandidates, bestVar] <- St[bestVar, currentCandidates] <- as.vector(t(m[, currentCandidates, drop = FALSE]) %*% m[,bestVar, drop = FALSE])
-  }
-
-  # Remove the empty rows in the stepInfo if stepLDA does not select all variables
-  stepInfo <- stepInfo[seq_along(currentVarList),]
-
-  return(list(currentVarList = idxOriginal[currentVarList], stepInfo = stepInfo, stopFlag = stopFlag))
-}
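
The deleted `ldaGSVD()` functionality now comes from the `folda` package, newly listed in `Imports` above. A rough equivalent of the old `ldaGSVD(Species ~ ., data = iris)` example, using the `folda::folda()` call signature that appears in `new_TreeeNode.R` below (the `type = "prob"` option is assumed to carry over from the old `predict.ldaGSVD()`):

library(folda)
fit <- folda(datX = iris[, -5], response = iris[, 5], subsetMethod = "all")
head(predict(fit, iris))                  # predicted classes
head(predict(fit, iris, type = "prob"))   # posterior probabilities
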
diff --git a/R/new_SingleTreee.R b/R/new_SingleTreee.R
index c2e4526..ccd914c 100644
--- a/R/new_SingleTreee.R
+++ b/R/new_SingleTreee.R
@@ -1,16 +1,25 @@
+#' Create a New Decision Tree
+#'
+#' This function builds a new decision tree based on input data and a variety of
+#' parameters, including LDA type, node model, and thresholds for splitting. The
+#' tree is grown recursively by splitting nodes, and child nodes are added until
+#' a stopping condition is met.
+#'
+#' @noRd
 new_SingleTreee <- function(datX,
                             response,
                             ldaType,
                             nodeModel,
-                            missingMethod,
-                            prior,
                             maxTreeLevel,
                             minNodeSize,
                             pThreshold,
-                            verbose,
-                            kSample){
+                            prior,
+                            misClassCost,
+                            missingMethod,
+                            kSample,
+                            verbose){
 
-  treeeList = structure(list(), class = "SingleTreee") # save the tree
+  treeeList = structure(list(), class = "Treee") # save the tree
 
   ### Initialize the first Node ###
 
@@ -21,19 +30,21 @@ new_SingleTreee <- function(datX,
                                    idxRow = seq_len(nrow(datX)),
                                    ldaType = ldaType,
                                    nodeModel = nodeModel,
-                                   missingMethod = missingMethod,
-                                   prior = prior,
                                    maxTreeLevel = maxTreeLevel,
                                    minNodeSize = minNodeSize,
+                                   prior = prior,
+                                   misClassCost = misClassCost,
+                                   missingMethod = missingMethod,
+                                   kSample = kSample,
                                    currentLevel = 0,
-                                   parentIndex = 0,
-                                   kSample = kSample)
+                                   parentIndex = 0)
+
   while(length(nodeStack) != 0){
     currentIdx <- nodeStack[1]; nodeStack <- nodeStack[-1] # pop the first element
     if(verbose) cat("The current index is:", currentIdx, "\n")
 
-    if(treeeList[[currentIdx]]$stopFlag == 0){ # if it has (potential) child nodes
+    if(treeeList[[currentIdx]]$stopInfo == "Normal"){ # if it has (potential) child nodes
 
       trainIndex <- attr(treeeList[[currentIdx]]$splitFun, "splitRes") # distribute the training set
 
@@ -46,31 +57,31 @@ new_SingleTreee <- function(datX,
                                            idxRow = treeeList[[currentIdx]]$idxRow[trainIndex[[i]]],
                                            ldaType = ldaType,
                                            nodeModel = nodeModel,
-                                           missingMethod = missingMethod,
-                                           prior = prior,
                                            maxTreeLevel = maxTreeLevel,
                                            minNodeSize = minNodeSize,
+                                           prior = prior,
+                                           misClassCost = misClassCost,
+                                           missingMethod = missingMethod,
+                                           kSample = kSample,
                                            currentLevel = treeeList[[currentIdx]]$currentLevel + 1,
-                                           parentIndex = currentIdx,
-                                           kSample = kSample))
+                                           parentIndex = currentIdx))
+
+      ### Stopping check ###
 
-      ### Stopping & pruning ###
       #> 1. update the p-value for loss drop
       lossBefore <- treeeList[[currentIdx]]$currentLoss
-      lossAfter <- do.call(sum,lapply(childNodes, function(node) node$currentLoss))
+      lossAfter <- do.call(sum, lapply(childNodes, function(node) node$currentLoss))
       treeeList[[currentIdx]]$alpha <- getOneSidedPvalue(N = length(treeeList[[currentIdx]]$idxRow),
                                                          lossBefore = lossBefore,
                                                          lossAfter = lossAfter)
       #> 2. pre-stopping
       if(treeeList[[currentIdx]]$alpha >= pThreshold){
-        treeeList[[currentIdx]]$stopFlag = 6
+        treeeList[[currentIdx]]$stopInfo = "Split is not significant"
         next
       }
 
-      ### Put child nodes in the tree ###
-
+      #> 3. Put child nodes in the tree
       childIdx <- seq_along(childNodes) + length(treeeList)
       treeeList[[currentIdx]]$children <- childIdx
      nodeStack <- c(nodeStack, childIdx)
@@ -79,6 +90,6 @@ new_SingleTreee <- function(datX,
   }
 
   for(i in seq_along(treeeList)) treeeList[[i]]$currentIndex <- i # assign the currentIndex
-
+  treeeList <- updateAlphaInTree(treeeList)
   return(treeeList)
 }
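
`getOneSidedPvalue()` is internal and not shown in this diff. One plausible form of the test it names, treating the training loss before and after a split as two proportions on the same `N` samples, is sketched below; this is a hypothetical reconstruction, not the package's implementation:

oneSidedPvalueSketch <- function(N, lossBefore, lossAfter){
  pPool <- (lossBefore + lossAfter) / (2 * N)     # pooled error rate under H0
  if(pPool <= 0 || pPool >= 1) return(1)          # degenerate: no evidence either way
  z <- (lossBefore - lossAfter) / (N * sqrt(pPool * (1 - pPool) * 2 / N))
  stats::pnorm(z, lower.tail = FALSE)             # small p-value => keep the split
}
oneSidedPvalueSketch(N = 100, lossBefore = 30, lossAfter = 15)  # ~0.006
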
diff --git a/R/new_TreeeNode.R b/R/new_TreeeNode.R
index 2c8e628..4b35931 100644
--- a/R/new_TreeeNode.R
+++ b/R/new_TreeeNode.R
@@ -1,107 +1,97 @@
+#' Create a New Tree Node in the Decision Tree
+#'
+#' This function creates a new node for the decision tree by fitting a model
+#' (such as ULDA or a mode model) based on the data at the current node. It
+#' checks for stopping conditions, fits the model, and generates splits if
+#' necessary.
+#'
+#' @noRd
 new_TreeeNode <- function(datX,
                           response,
                           idxCol,
                           idxRow,
                           ldaType,
                           nodeModel,
-                          missingMethod,
-                          prior,
                           maxTreeLevel,
                           minNodeSize,
+                          prior,
+                          misClassCost,
+                          missingMethod,
+                          kSample,
                           currentLevel,
-                          parentIndex,
-                          kSample) {
-
+                          parentIndex) {
 
   # Data Cleaning -----------------------------------------------------------
 
   #> Remove empty levels due to partition
   xCurrent <- droplevels(datX[idxRow, idxCol, drop = FALSE])
   responseCurrent <- droplevels(response[idxRow])
-
-  #> Fix the missing values
-  imputedSummary <- missingFix(data = xCurrent, missingMethod = missingMethod)
-  xCurrent <- imputedSummary$data
-
-  #> NOTICE: The missingRef should not be subset after constant check, since there
-  #> are cases when the original X are constant after imputation, but its flag is important
-  idxCurrColKeep <- constantColCheck(data = xCurrent)
-  xCurrent <- xCurrent[, idxCurrColKeep, drop = FALSE]
-
+  priorAndMisClassCost <- updatePriorAndMisClassCost(prior = prior, misClassCost = misClassCost, response = responseCurrent, insideNode = TRUE)
+  prior <- priorAndMisClassCost$prior; misClassCost <- priorAndMisClassCost$misClassCost
 
   # Model Fitting -----------------------------------------------------------
 
-  #> check stopping
-  stopFlag <- stopCheck(responseCurrent = responseCurrent,
+  stopInfo <- stopCheck(responseCurrent = responseCurrent,
                         numCol = ncol(xCurrent),
                         maxTreeLevel = maxTreeLevel,
                         minNodeSize = minNodeSize,
-                        currentLevel = currentLevel) # # 0/1/2: Normal/Stop+Mode/Stop+LDA
+                        currentLevel = currentLevel) # Normal/Stop+Mode/Stop+ULDA
 
-  #> Based on the node model, decide whether we should fit LDA
-  if(nodeModel == "LDA" | stopFlag == 0){ # LDA model, or mode model with LDA splits
-    if(stopFlag == 1){ # when LDA can not be fitted
+  #> Based on the node model, decide whether we should fit ULDA
+  if(nodeModel == "ULDA" | stopInfo == "Normal"){ # ULDA model, or mode model with ULDA splits
+    if(stopInfo == "Insufficient data"){ # when LDA can not be fitted
       nodeModel <- "mode"
     } else{
-      #> Empty response level can not be dropped if prior exists
-      # splitLDA <- nodePredict <- ldaGSVD(datX = xCurrent, response = responseCurrent, method = ldaType, fixNA = FALSE, prior = prior, insideTree = TRUE)
-      samplingRows <- sampleForLDA(response = responseCurrent, prior = prior, K = kSample)
-      splitLDA <- nodePredict <- ldaGSVD(datX = xCurrent[samplingRows$idxFinal,, drop = FALSE],
-                                         response = responseCurrent[samplingRows$idxFinal],
-                                         method = ldaType,
-                                         fixNA = FALSE,
-                                         prior = samplingRows$prior,
-                                         insideTree = TRUE)
+      splitLDA <- nodePredict <- folda::folda(datX = xCurrent,
+                                              response = responseCurrent,
+                                              subsetMethod = ldaType,
+                                              prior = prior,
+                                              misClassCost = misClassCost,
+                                              missingMethod = missingMethod,
+                                              downSampling = (kSample != -1),
+                                              kSample = kSample)
       resubPredict <- predict(object = nodePredict, newdata = xCurrent)
+      currentLoss = sum(resubPredict != responseCurrent) # save the currentLoss for future accuracy calculation
+      #> if not as good as mode, change it to mode,
+      #> but the splitting goes on, since the next split might be better.
+      if(currentLoss >= length(responseCurrent) - max(table(responseCurrent))) nodeModel <- "mode"
     }
   }
 
   if(nodeModel == "mode"){
     nodePredict <- getMode(responseCurrent, prior = prior)
     resubPredict <- rep(nodePredict, length(responseCurrent))
-  }
-
-  currentLoss = sum(resubPredict != responseCurrent) # save the currentLoss for future accuracy calculation
-
-  #> if not as good as mode, change it to mode,
-  #> but the splitting goes on, since the next split might be better.
-  #> The code is subjective to change if prior will be added
-  if(currentLoss >= length(responseCurrent) - max(unname(table(responseCurrent)))){
-    nodeModel <- "mode"
-    nodePredict <- getMode(responseCurrent, prior = prior)
-    resubPredict <- rep(nodePredict, length(responseCurrent))
     currentLoss = sum(resubPredict != responseCurrent)
   }
 
-
   # Splits Generating -----------------------------------------------------------
 
-  #> Generate the splits
-  if(stopFlag == 0){ # if splitting goes on, find the splits
+  if(stopInfo == "Normal"){ # if splitting goes on, find the splits
     splitFun <- getSplitFunLDA(datX = xCurrent,
-                               response = responseCurrent,
-                               modelLDA = splitLDA)
+                               modelULDA = splitLDA)
-    if(is.null(splitFun)) stopFlag <- 4 # no splits
+    if(is.null(splitFun)) stopInfo <- "No feasible splits"
   } else splitFun <- NULL
 
-
   # Final Results -----------------------------------------------------------
 
   currentTreeeNode <- list(
-    # currentIndex = currentIndex, # will be updated in new_SingleTreee()
     currentLevel = currentLevel,
     idxCol = idxCol,
     idxRow = idxRow,
     currentLoss = currentLoss, # this loss should account for sample size
     accuracy = 1 - currentLoss / length(responseCurrent),
-    stopFlag = stopFlag,
+    stopInfo = stopInfo,
    proportions = table(responseCurrent, dnn = NULL), # remove the name of the table
+    prior = prior,
+    misClassCost = misClassCost,
     parent = parentIndex,
     children = c(), # is.null to check terminal nodes
-    misReference = imputedSummary$ref,
     splitFun = splitFun, # save the splitting rules
-    # alpha = NA, # p-value from t-test, to measure the split's strength for model selection
     nodeModel = nodeModel,
     nodePredict = nodePredict # predict Function
+    # currentIndex = currentIndex, # will be updated in new_SingleTreee()
+    # alpha = NA, # p-value from t-test, to measure the split's strength for model selection
+    # pruned = NULL # generated during pruning
   )
 
   class(currentTreeeNode) <- "TreeeNode" # Set the name for the class
   return(currentTreeeNode)
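
`getMode()` is likewise internal. Under a class prior, one plausible plurality rule weights each class count by its prior; the sketch below is hypothetical and only illustrates the idea, not the package's actual definition:

getModeSketch <- function(y, prior = NULL){
  counts <- table(y)
  if(is.null(prior)) prior <- rep(1, length(counts))   # unweighted by default
  names(which.max(counts * prior))  # class with the largest prior-weighted count
}
getModeSketch(iris$Species)  # ties broken by first level: "setosa"
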
When you click on the node, an -#' information panel with more details will appear. +#' A full tree diagram is displayed using [visNetwork] when `node` is not +#' specified (the default is `-1`). The color represents the most common +#' (plurality) class within each node, and the size of each terminal node +#' reflects its relative sample size. Below each node, the fraction of +#' correctly predicted training samples and the total sample size for that +#' node are shown, along with the node index. Clicking on a node opens an +#' information panel with additional details. #' -#' @section Individual plot for each node: +#' @section Individual Node Plot: #' -#' The node index and the original training data are required to return a more -#' detailed plot within a specific node. The density plot will be provided -#' when only two levels are left for the response variable in a node (like in -#' a binary classification problem). Samples are projected down to their first -#' linear discriminant scores (LD1). A scatter plot will be provided if a node -#' contains more than two classes. Samples are projected down to their first -#' and second linear discriminant scores. +#' To plot a specific node, you must provide the node index along with the +#' original training predictors (`datX`) and responses (`response`). A scatter +#' plot is generated if more than one discriminant score is available, +#' otherwise, a density plot is created. Samples are projected onto their +#' linear discriminant score(s). #' -#' @param x a fitted model object of class `Treee`, which is assumed to be the -#' result of the [Treee()] function. -#' @param data the original data you used to fit the `Treee` object if you want -#' the individual plot for each node. Otherwise, you can leave this parameter -#' blank if you only need the overall tree structure diagram. -#' @param node the node index that you are interested in. By default, it is set -#' to `-1` and the overall tree structure is drawn. -#' @param ... further arguments passed to or from other methods. -#' -#' @returns For overall tree structure (`node = -1`), A figure of class -#' `visNetwork` is drawn. Otherwise, a figure of class `ggplot` is drawn. +#' @param x A fitted model object of class `Treee`, typically the result of the +#' [Treee()] function. +#' @param datX A data frame of predictor variables. Required for plotting +#' individual nodes. +#' @param response A vector of response values. Required for plotting individual +#' nodes. +#' @param node An integer specifying the node to plot. If `node = -1`, the +#' entire tree is plotted. Default is `-1`. +#' @param ... Additional arguments passed to the plotting functions. #' +#' @return A `visNetwork` interactive plot of the decision tree if `node = -1`, +#' or a `ggplot2` object if a specific node is plotted. 
#' @export #' #' @examples -#' fit <- Treee(Species~., data = iris) -#' # plot the overall tree -#' plot(fit) -#' # plot a certain node -#' plot(fit, iris, node = 1) -plot.Treee <- function(tree, datX, response, node = -1, ...){ - treeeOutput <- tree - if(node>0){ - if(missing(datX) | missing(response)) stop("Please input the orginal training data for nodewise LDA plots") - if(treeeOutput$treee[[node]]$nodeModel == "mode") return(paste("Every observation in this node is predicted to be", treeeOutput$treee[[node]]$nodePredict)) - # Get the data ready, impute the NAs (if any) - response <- as.factor(response) - newX <- getDataInShape(data = datX[treeeOutput$treee[[node]]$idxRow,,drop = FALSE], missingReference = treeeOutput$treee[[node]]$misReference) - colorIdx <- match(names(treeeOutput$treee[[node]]$proportions), levels(response)) - - plotLDA2d(ldaModel = treeeOutput$treee[[node]]$nodePredict, - data = cbind.data.frame(response = response[treeeOutput$treee[[node]]$idxRow], newX), - node = node, - colorManual = scales::hue_pal()(nlevels(response))[colorIdx]) - }else{ # default overall plot - plot(treeeOutput$treee) - } -} +#' fit <- Treee(datX = iris[, -5], response = iris[, 5], verbose = FALSE) +#' plot(fit) # plot the overall tree +#' plot(fit, datX = iris, response = iris[, 5], node = 1) # plot a specific node +plot.Treee <- function(x, datX, response, node = -1, ...){ + # Save the color manual, since some classes might be empty during branching + colorManual = grDevices::hcl.colors(length(x[[1]]$proportions)) + names(colorManual) <- responseLevels <- names(x[[1]]$proportions) + if(node < 0){ # Overall tree plot + idTransVec <- seq_along(x) + nodes <- do.call(rbind, lapply(x, function(treeeNode) nodesHelper(treeeNode = treeeNode, idTransVec = idTransVec))) + edges <- do.call(rbind, lapply(x, edgesHelper)) + p <- visNetwork::visNetwork(nodes, edges, width = "100%", height = "600px")%>% + visNetwork::visNodes(shape = 'dot')%>% + visNetwork::visHierarchicalLayout(levelSeparation = 100)%>% + visNetwork::visInteraction(dragNodes = FALSE, + dragView = TRUE, + zoomView = TRUE) -#' @export -plot.SingleTreee <- function(x, ...){ - idTransVec <- seq_along(x) - nodes <- do.call(rbind, sapply(x, function(treeeNode) nodesHelper(treeeNode = treeeNode, idTransVec = idTransVec),simplify = FALSE)) - edges <- do.call(rbind, sapply(x, edgesHelper,simplify = FALSE)) - p <- visNetwork::visNetwork(nodes, edges, width = "100%", height = "600px")%>% - visNetwork::visNodes(shape = 'dot')%>% - visNetwork::visHierarchicalLayout(levelSeparation = 100)%>% - visNetwork::visLegend(width = 0.1, position = "right", main = "Group")%>% - visNetwork::visInteraction(dragNodes = FALSE, - dragView = TRUE, - zoomView = TRUE) + for (i in seq_along(responseLevels)) { + p <- p %>% visNetwork::visGroups(groupname = responseLevels[i], color = unname(colorManual[i])) + } - # Change the color manual - colorManual = scales::hue_pal()(length(x[[1]]$proportions)) - for(i in seq_along(colorManual)){ - p <- p %>% visNetwork::visGroups(groupname = names(x[[1]]$proportions)[i], color = colorManual[i]) + legend_nodes <- lapply(seq_along(responseLevels), function(i) { # add legends + list(label = responseLevels[i], + shape = "dot", + color = unname(colorManual[i])) + }) + p <- p %>% visNetwork::visLegend(addNodes = legend_nodes, width = 0.1, useGroups = FALSE, position = "right", main = "Class") + } else{ # individual node plot + if(x[[node]]$nodeModel == "mode") return(paste("Every observation in node", node, "is predicted to be", 
x[[node]]$nodePredict)) + if(missing(datX) || missing(response)) stop("Please input the training X and Y for the nodewise plot") + colorIdx <- match(names(x[[node]]$proportions), levels(response)) + p <- plot(x = x[[node]]$nodePredict, + datX = datX[x[[node]]$idxRow,,drop = FALSE], + response = response[x[[node]]$idxRow]) + p$scales$scales <- list() # remove old color palette + p <- p + + ggplot2::scale_color_manual(values = colorManual[colorIdx])+ + ggplot2::scale_fill_manual(values = colorManual[colorIdx])+ + ggplot2::labs(caption = paste("Node", node)) } - return(p) } + infoClickSingle <- function(treeeNode, idTransVec){ - line1 = '#### Information Panel ####' - line2 = paste('
Current Node Index:', idTransVec[treeeNode$currentIndex]) - line3 = paste('
There are', length(treeeNode$idxRow), 'data in this node') - # line4 = paste('
The proportion of', paste(names(treeeNode$proportions), collapse = ', '),'are', - # paste(sprintf("%.1f%%", treeeNode$proportions / length(treeeNode$idxRow) * 100), collapse = ', ')) - # line4 = paste('
', length(treeeNode$idxRow) - treeeNode$currentLoss, 'of them are correctly classified') - line5 = paste('
The resubstitution acc is ', round(treeeNode$accuracy,3)) - line5.5 = paste('
Plurality class (', round(max(treeeNode$proportions) / sum(treeeNode$proportions),4)*100, '%) is ', names(sort(treeeNode$proportions, decreasing = TRUE))[1], sep = "") - line6 = paste('
The model in this node is ', treeeNode$nodeModel) - line6.3 = paste("
Pillai's trace is ", tryCatch({treeeNode$nodePredict$statPillai}, error = function(e) {NULL})) - line6.5 = paste('
p value is ', tryCatch({treeeNode$nodePredict$pValue}, error = function(e) {NULL})) - line6.6 = paste('
predGini is ', tryCatch({treeeNode$nodePredict$predGini}, error = function(e) {NULL})) - line6.7 = paste('
alpha is ', tryCatch({treeeNode$alpha}, error = function(e) {NULL})) - line7 = paste('
stopFlag is ', treeeNode$stopFlag) - return(paste(line1,line2,line3,line5,line5.5,line6,line6.3,line6.5,line6.6,line6.7,line7)) + line1 = paste('#### Information Panel: Node', idTransVec[treeeNode$currentIndex], '####') + line2 = paste('
There are', length(treeeNode$idxRow), 'observations in this node')
+  line3 = paste('
The resubstitution accuracy is ', round(treeeNode$accuracy, 3))
+  line4 = paste('
Plurality class (', round(max(treeeNode$proportions) / sum(treeeNode$proportions), 4) * 100, '%) is ', names(sort(treeeNode$proportions, decreasing = TRUE))[1], sep = "")
+  line5 = paste('
The model in this node is ', treeeNode$nodeModel) + line6 = paste('
stopInfo:', treeeNode$stopInfo) + + if (treeeNode$nodeModel != "mode") { + line7 = paste("
Pillai's trace is ", round(treeeNode$nodePredict$statPillai, 3)) + line8 = paste('
MANOVA p value is ', format(treeeNode$nodePredict$pValue, scientific = TRUE, digits = 4))
+    line9 = paste('
Gini Index is ', format(treeeNode$nodePredict$predGini, scientific = TRUE, digits = 4))
+    line10 = paste('
Splitting p value is ', format(treeeNode$alpha, scientific= TRUE, digits = 4)) + } else line7 = line8 = line9 = line10 = "" + + return(paste(line1,line2,line3,line4,line5,line6,line7,line8,line9,line10)) } @@ -111,13 +109,8 @@ nodesHelper <- function(treeeNode, idTransVec){ value = ifelse(terminalFlag, log(length(treeeNode$idxRow)), 2) # node size level = treeeNode$currentLevel group = names(sort(treeeNode$proportions, decreasing = TRUE))[1] - label = paste(# group, # paste(treeeNode$proportions, collapse = ' / '), - paste(length(treeeNode$idxRow) - treeeNode$currentLoss, length(treeeNode$idxRow), sep = ' / '), - # treeeNode$currentLoss, + label = paste(paste(length(treeeNode$idxRow) - treeeNode$currentLoss, length(treeeNode$idxRow), sep = ' / '), paste('Node', idTransVec[id]),sep = "\n") - # paste('alpha:', treeeNode$alpha) - # paste('Tnodes:', paste(treeeNode$offsprings, collapse = "/")), - # paste('Pruned:', treeeNode$pruned) return(data.frame(id, title, value, level, group, label, shadow = TRUE)) } @@ -128,34 +121,3 @@ edgesHelper <- function(treeeNode){ return(data.frame(from = treeeNode$currentIndex, to = treeeNode$children)) } } - -plotLDA2d <- function(ldaModel, data, node, colorManual){ - LD1 <- LD2 <- response <- NULL # walk around the binding error in R CMD check - # browser() - if(dim(ldaModel$scaling)[2] == 1){ - # Only one LD is available, draw the histogram - datCombined <- cbind.data.frame(response = data$response, LD1 = getLDscores(modelLDA = ldaModel, data = data, nScores = 1)) - estimatedPrior <- table(datCombined$response) / length(datCombined$response) - estimatedPrior <- estimatedPrior[which(estimatedPrior != 0)] # some classes are not available - datPlot <- do.call(rbind, lapply(seq_along(estimatedPrior), function(i) cbind(with(density(datCombined$LD1[datCombined$response == names(estimatedPrior)[i]]), data.frame(LD1 = x, density = y * estimatedPrior[i])), response = names(estimatedPrior)[i]))) - datPlot$response <- factor(datPlot$response, levels = names(estimatedPrior)) - p <- ggplot2::ggplot(data = datPlot)+ - ggplot2::geom_line(ggplot2::aes(x = LD1, y = density, color = response))+ - ggplot2::geom_ribbon(ggplot2::aes(x = LD1, ymin = 0, ymax = density, fill = response), alpha = 0.5)+ - ggplot2::scale_fill_manual(values = colorManual)+ - ggplot2::theme_bw()+ - ggplot2::labs(title = "Density plot of LD1", subtitle = paste("Node:",node)) - }else{ - LDscores <- getLDscores(modelLDA = ldaModel, data = data, nScores = 2) - datPlot <- cbind.data.frame(response = data$response, LDscores) - p <- ggplot2::ggplot(data = datPlot)+ - ggplot2::geom_point(ggplot2::aes(x = LD1, y = LD2, color = response), alpha = 0.7)+ - ggplot2::scale_color_manual(values = colorManual)+ - ggplot2::theme_bw()+ - ggplot2::labs(title = "Scatter plot by first two LDscores", subtitle = paste("Node:",node)) - } - return(p) -} - - - diff --git a/R/predict.R b/R/predict.R index 557d6e2..2b935d0 100644 --- a/R/predict.R +++ b/R/predict.R @@ -1,50 +1,47 @@ -#'Predictions from a fitted Treee object +#' Predictions From a Fitted Treee Object #' -#'Prediction of test data using a fitted Treee object +#' Generate predictions on new data using a fitted `Treee` model. #' -#'@param object a fitted model object of class `Treee`, which is assumed to be -#' the result of the [Treee()] function. -#'@param newdata data frame containing the values at which predictions are -#' required. Missing values are allowed. -#'@param type character string denoting the type of predicted value returned. 
-#' The default is to return the predicted class (`type = 'response'`). The
-#' predicted posterior probabilities for each class will be returned if `type =
-#' 'prob'`. `'all'` returns a data frame with predicted classes, posterior
-#' probabilities, and the predicted node indices.
-#' @param ... further arguments passed to or from other methods.
+#' @param object A fitted model object of class `Treee`, typically the result of
+#'   the [Treee()] function.
+#' @param newdata A data frame containing the predictor variables. Missing
+#'   values are allowed and will be handled according to the fitted tree's
+#'   method for handling missing data.
+#' @param type A character string specifying the type of prediction to return.
+#'   Options are:
+#'   * `'response'`: returns the predicted class for each observation (default).
+#'   * `'prob'`: returns a data frame of posterior probabilities for each class.
+#'   * `'all'`: returns a data frame containing predicted classes, posterior probabilities, and the predicted node indices.
+#' @param ... Additional arguments passed to or from other methods.
 #'
-#'@returns The function returns different values based on the `type`, if
-#' * `type = 'response'`: vector of predicted responses.
-#' * `type = 'prob'`: a data frame of the posterior probabilities. Each class takes a column.
-#' * `type = 'all'`: a data frame contains the predicted responses, posterior probabilities, and the predicted node indices.
+#' @return Depending on the value of `type`, the function returns:
+#'   * If `type = 'response'`: A character vector of predicted class labels.
+#'   * If `type = 'prob'`: A data frame of posterior probabilities, where each class has its own column.
+#'   * If `type = 'all'`: A data frame containing predicted class labels, posterior probabilities, and the predicted node indices.
 #'
-#'Note: for factor predictors, if it contains a level which is not used to
-#' grow the tree, it will be converted to missing and will be imputed according
-#' to the `missingMethod` in the fitted tree.
-#'@export
+#' Note: For factor predictors, if a level not present in the training data is
+#'   found in `newdata`, it will be treated as missing and handled according to
+#'   the `missingMethod` specified in the fitted tree.
+#'
+#' @export
 #'
 #' @examples
-#' fit <- Treee(Species~., data = iris)
-#' predict(fit,iris)
-#' # output prosterior probabilities
-#' predict(fit,iris,type = "prob")
-predict.Treee <- function(object, newdata, type = c("response", "prob", "all"), insideCV = FALSE, newY = NULL, ...){
-  # input type: data.frame / matrix / vector
-  # if(!inherits(object, "Treee")) stop("object not of class \"Treee\"")
-  stopifnot(is.data.frame(newdata))
-
-  type <- match.arg(type, c("response", "prob", "all"))
-
-  return(predict(object$treee, newdata = newdata, type = type, insideCV = insideCV, newY = newY))
-}
+#' fit <- Treee(datX = iris[, -5], response = iris[, 5], verbose = FALSE)
+#' head(predict(fit, iris)) # Predicted classes
+#' head(predict(fit, iris[, -5], type = "prob")) # Posterior probabilities
+#' head(predict(fit, iris[, -5], type = "all")) # Full details
+predict.Treee <- function(object, newdata, type = c("response", "prob", "all"), ...){
+  if (!is.data.frame(newdata)) stop("newdata must be a data frame")
+  dots <- list(...)
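+  # `insideCV` and `obsY` are internal arguments forwarded through `...` by the
+  # cross-validation pruning routine; they are not part of the public interface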
+ insideCV <- if (!is.null(dots$insideCV)) dots$insideCV else FALSE + obsY <- if (!is.null(dots$obsY)) dots$obsY else NULL -#' @export -predict.SingleTreee <- function(object, newdata, type = "response", insideCV = FALSE, newY = NULL, ...){ + type <- match.arg(type, c("response", "prob", "all")) cname <- names(object[[1]]$proportions) # find the class names res <- data.frame(response = character(nrow(newdata)), node = numeric(nrow(newdata)), - newCols = matrix(0,nrow = nrow(newdata), ncol = length(cname))) + newCols = matrix(0, nrow = nrow(newdata), ncol = length(cname))) colnames(res)[2+seq_along(cname)] <- cname # posertior probs' name nodeList <- vector(mode = "list", length = length(object)) # keep the testing obs index nodeList[[1]] <- seq_len(nrow(newdata)) @@ -58,22 +55,20 @@ predict.SingleTreee <- function(object, newdata, type = "response", insideCV = F if(length(currentObs) == 0) next if(insideCV){ - object[[currentIdx]]$CVerror <- sum(predNode(data = newdata[currentObs,,drop = FALSE], + object[[currentIdx]]$CVerror <- sum(predNode(datX = newdata[currentObs,,drop = FALSE], treeeNode = currentNode, - missingReference = currentNode$misReference, - type = "response") != newY[currentObs]) + type = "response") != obsY[currentObs]) } if(is.null(currentNode$children)){ # terminal nodes res$node[currentObs] <- currentIdx - posteriorProbs <- predNode(data = newdata[currentObs,,drop = FALSE], + posteriorProbs <- predNode(datX = newdata[currentObs,,drop = FALSE], treeeNode = currentNode, - missingReference = currentNode$misReference, type = "prob") - res$response[currentObs] <- colnames(posteriorProbs)[max.col(posteriorProbs, ties.method = "first")] + res$response[currentObs] <- colnames(posteriorProbs)[max.col(-posteriorProbs %*% t(currentNode$misClassCost), ties.method = "first")] res[currentObs, match(colnames(posteriorProbs), colnames(res))] <- posteriorProbs }else{ # internal nodes - trainIndex <- currentNode$splitFun(datX = newdata[currentObs,,drop = FALSE], missingReference = currentNode$misReference) + trainIndex <- currentNode$splitFun(datX = newdata[currentObs,,drop = FALSE]) nodeStack <- c(nodeStack, currentNode$children) for(i in seq_along(currentNode$children)) nodeList[[currentNode$children[i]]] <- currentObs[trainIndex[[i]]] } @@ -86,6 +81,18 @@ predict.SingleTreee <- function(object, newdata, type = "response", insideCV = F } +predNode <- function(datX, treeeNode, type){ + if(treeeNode$nodeModel != "ULDA"){ + if(type == "response"){ + pred <- rep(treeeNode$nodePredict, nrow(datX)) + } else{ # if type = "all", the extra response column will be added later + pred <- matrix(0,nrow = nrow(datX), ncol = length(treeeNode$proportions), dimnames = list(c(), names(treeeNode$proportions))) + pred[,which(treeeNode$nodePredict == colnames(pred))] <- 1 + } + } else pred <- predict(object = treeeNode$nodePredict, newdata = datX, type = type) + return(pred) +} + diff --git a/R/print.R b/R/print.R index eeff490..fd15e1a 100644 --- a/R/print.R +++ b/R/print.R @@ -1,10 +1,3 @@ -#' @export -print.Treee <- function(x, ...){ - print(x$treee) - invisible(x) -} - - #' @export print.TreeeNode <- function(x, ...){ cat(paste0("Node ",x$currentIndex,":\n")) @@ -22,8 +15,6 @@ print.TreeeNode <- function(x, ...){ #' @export summary.TreeeNode <- function(object, ...){ # just print out everything besides some super long info - object$idxRow <- object$idxCol <- object$nodePredict <- object$misReference <- NULL - # class(object) <- "summary.TreeeNode" - # return(object) + object$idxRow <- object$idxCol <- 
object$nodePredict <- object$splitFun <- NULL return(unclass(object)) } diff --git a/R/prune.R b/R/prune.R index cdd67c2..a34b5e5 100644 --- a/R/prune.R +++ b/R/prune.R @@ -1,28 +1,26 @@ - -# pruning ----------------------------------------------------------------- - +#' Prune a Decision Tree +#' +#' This function performs pruning on an existing decision tree by using +#' cross-validation to assess the best level of tree complexity, balancing the +#' trade-off between tree size and predictive performance. +#' +#' @noRd prune <- function(oldTreee, - numberOfPruning, datX, response, ldaType, nodeModel, - missingMethod, - prior, + numberOfPruning, maxTreeLevel, minNodeSize, pThreshold, - verbose, - kSample){ - - - # Parameter Clean Up ------------------------------------------------------ - - oldTreee <- updateAlphaInTree(oldTreee) # make the alpha monotone - treeeSaved = oldTreee - - # pruning and error estimate ---------------------------------------------- + prior, + misClassCost, + missingMethod, + kSample, + verbose){ + treeeSaved = oldTreee # saved a copy for cutting idxCV <- sample(seq_len(numberOfPruning), length(response), replace = TRUE) treeeForPruning <- vector(mode = "list", length = numberOfPruning) @@ -32,51 +30,46 @@ prune <- function(oldTreee, } for(i in seq_len(numberOfPruning)){ - treeeForPruning[[i]] <- predict(new_SingleTreee(datX = datX[idxCV!=i, , drop = FALSE], - response = response[idxCV!=i], + treeeForPruning[[i]] <- predict(new_SingleTreee(datX = datX[idxCV != i, , drop = FALSE], + response = response[idxCV != i], ldaType = ldaType, nodeModel = nodeModel, - missingMethod = missingMethod, - prior = prior, maxTreeLevel = maxTreeLevel, minNodeSize = minNodeSize, pThreshold = 0.51, - verbose = FALSE, - kSample = kSample), - newdata = datX[idxCV ==i, , drop = FALSE], + prior = prior, + misClassCost = misClassCost, + missingMethod = missingMethod, + kSample = kSample, + verbose = FALSE), + newdata = datX[idxCV == i, , drop = FALSE], insideCV = TRUE, - newY = response[idxCV==i]) + obsY = response[idxCV == i]) if(verbose) utils::setTxtProgressBar(pbGrowTree, i) } - treeeForPruning <- lapply(treeeForPruning, updateAlphaInTree) - - CV_Table <- data.frame() - numOfPruning <- 0 if(verbose){ cat('\nPruning: Prune the trees...\n') pbPruneTree <- utils::txtProgressBar(min = 0, max = length(treeeSaved), style = 3) } + CV_Table <- data.frame(); pruningCounter <- 0 + while(TRUE){ nodesCount <- sum(sapply(oldTreee, function(treeeNode) is.null(treeeNode$pruned))) if(verbose) utils::setTxtProgressBar(pbPruneTree, length(treeeSaved) - nodesCount) meanAndSE <- getMeanAndSE(treeeListList = treeeForPruning) - currentCutAlpha = getCutAlpha(treeeList = oldTreee) - - # summary output - CV_Table <- rbind(CV_Table, c(numOfPruning, nodesCount, meanAndSE, currentCutAlpha)) + CV_Table <- rbind(CV_Table, c(pruningCounter, nodesCount, meanAndSE, currentCutAlpha)) if (nodesCount == 1) { if(verbose) utils::setTxtProgressBar(pbPruneTree, length(treeeSaved)) break } - # Cut the treee oldTreee <- pruneTreee(treeeList = oldTreee, alpha = currentCutAlpha) treeeForPruning <- lapply(treeeForPruning, function(treeeList) pruneTreee(treeeList = treeeList, alpha = currentCutAlpha)) - numOfPruning <- numOfPruning + 1 + pruningCounter <- pruningCounter + 1 } colnames(CV_Table) <- c("treeeNo", "nodeCount", "meanMSE", "seMSE", "alpha") @@ -86,45 +79,51 @@ prune <- function(oldTreee, kSE = 0.25 pruneThreshold <- (CV_Table$meanMSE + kSE * CV_Table$seMSE)[which.min(CV_Table$meanMSE)] idxFinal <- dim(CV_Table)[1] + 1 - 
which.max(rev(CV_Table$meanMSE <= pruneThreshold)) - for(i in seq_len(idxFinal-1)){ - treeeSaved <- pruneTreee(treeeList = treeeSaved, alpha = CV_Table$alpha[i]) - } - treeeNew <- dropNodes(treeeSaved) + for(i in seq_len(idxFinal-1)) treeeSaved <- pruneTreee(treeeList = treeeSaved, alpha = CV_Table$alpha[i]) - return(list(treeeNew = treeeNew, - CV_Table = CV_Table)) + return(list(treeeNew = dropNodes(treeeSaved), CV_Table = CV_Table)) } + +#' Update Alpha Values in a Tree +#' +#' This function calculates and updates the alpha values for each node in the +#' decision tree. Alpha is a measure (p-value) of the improvement in loss after +#' pruning a subtree. +#' +#' @noRd updateAlphaInTree <- function(treeeList){ - #> Purpose: Calculate alpha, and make it monotonic #> assumption: the index of the children must be larger than its parent's - for(i in rev(seq_along(treeeList))){ if(is.null(treeeList[[i]]$pruned)){ # Only loop over the unpruned nodes ## Get all terminal nodes - treeeList[[i]]$offsprings <- getTerminalNodes(currentIdx = i, treeeList = treeeList) + treeeList[[i]]$childrenTerminal <- getTerminalNodes(currentIdx = i, treeeList = treeeList) + ## Get re-substitution error - treeeList[[i]]$offspringLoss <- sum(sapply(treeeList[[i]]$offsprings, function(idx) treeeList[[idx]]$currentLoss)) - # treeeList[[i]]$alpha <- (treeeList[[i]]$currentLoss - treeeList[[i]]$offspringLoss) / (length(treeeList[[i]]$offsprings) - 1) + treeeList[[i]]$childrenTerminalLoss <- sum(sapply(treeeList[[i]]$childrenTerminal, function(idx) treeeList[[idx]]$currentLoss)) treeeList[[i]]$alpha <- getOneSidedPvalue(N = length(treeeList[[i]]$idxRow), lossBefore = treeeList[[i]]$currentLoss, - lossAfter = treeeList[[i]]$offspringLoss) - ## Update alpha to be monotonic - # childrenAlpha <- sapply(treeeList[[i]]$children, function(idx) treeeList[[idx]]$alpha) - #> In case that the tree is not root, but all alpha are the same - #> we add one to the root node alpha to separate them - # treeeList[[i]]$alpha <- min(c(Inf, treeeList[[i]]$alpha, unlist(childrenAlpha)), na.rm = TRUE) + lossAfter = treeeList[[i]]$childrenTerminalLoss) } } - return(treeeList) } +#' Calculate Mean and Standard Error of Tree Errors +#' +#' This function computes the mean and standard error of the errors from a list +#' of decision trees. +#' +#' @noRd +#' +#' @param treeeListList A list of decision tree objects, where each tree +#' contains information about its error. +#' getMeanAndSE <- function(treeeListList){ error <- sapply(treeeListList, getMeanAndSEhelper) - return(c(mean(error), sd(error) / sqrt(length(treeeListList)))) + return(c(mean(error), stats::sd(error) / sqrt(length(treeeListList)))) } getMeanAndSEhelper <- function(treeeList){ @@ -133,28 +132,39 @@ getMeanAndSEhelper <- function(treeeList){ } +#' Calculate the Cut-off Alpha for Pruning +#' +#' This function calculates the cut-off alpha value for pruning the decision +#' tree. The alpha value is computed from the internal (non-terminal) nodes of +#' the tree, representing the improvement in loss after pruning. The geometric +#' mean of the two largest alpha values is returned as the cut-off. 
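+#'
+#' A minimal sketch of the computation with hypothetical p-values (`alphas`
+#' below stands in for the internal nodes' alpha values):
+#'
+#' ```
+#' alphas <- c(0.04, 0.2, 0.5)
+#' top2 <- sort(alphas, decreasing = TRUE)[1:2] # 0.5 and 0.2
+#' exp(mean(log(top2))) # geometric mean, roughly 0.316
+#' ```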
+#' +#' @noRd getCutAlpha <- function(treeeList){ - # get alpha for all non-terminal nodes, NA for terminal nodes - # alphaList <- unique(sapply(treeeList, function(treeeNode) ifelse(is.null(treeeNode$children), NA, treeeNode$alpha))) - - #> only count alpha from internal nodes internalIdx <- setdiff(getTerminalNodes(1, treeeList, keepNonTerminal = T), getTerminalNodes(1, treeeList, keepNonTerminal = F)) alphaList <- unique(do.call(c, lapply(internalIdx, function(i) treeeList[[i]]$alpha))) + alphaCandidates <- stats::na.omit(sort(alphaList, decreasing = TRUE)[1:2]) - #> Geometry average - #> When no alpha is available, use .Machine$double.eps instead - alphaCandidates <- na.omit(sort(alphaList, decreasing = TRUE)[1:2]) if(length(alphaCandidates) == 0) alphaCandidates <- 1 return(exp(mean(log(alphaCandidates)))) } +#' Prune a Decision Tree Based on Alpha +#' +#' This function prunes the decision tree by removing branches whose alpha +#' values are greater than or equal to the given threshold. Pruning is performed +#' on non-terminal nodes that meet the alpha condition, and their child nodes +#' are marked as pruned. After pruning, the alpha values of the remaining nodes +#' are updated. +#' +#' @noRd pruneTreee <- function(treeeList, alpha){ for(i in rev(seq_along(treeeList))){ treeeNode <- treeeList[[i]] # not yet pruned + non-terminal node + alpha below threshold - currentFlag <- is.null(treeeNode$pruned) & !is.null(treeeNode$children) & treeeNode$alpha >= alpha - 1e-10 # R rounding error, 1e-10 needed + currentFlag <- is.null(treeeNode$pruned) && !is.null(treeeNode$children) && treeeNode$alpha >= alpha - 1e-10 # R rounding error, 1e-10 needed if(currentFlag){ allChildren <- getTerminalNodes(currentIdx = i, treeeList = treeeList, keepNonTerminal = TRUE) @@ -167,16 +177,24 @@ pruneTreee <- function(treeeList, alpha){ } +#' Drop Pruned Nodes from a Decision Tree +#' +#' This function removes pruned nodes from the decision tree, keeping only +#' terminal and relevant internal nodes. It reassigns the indices of the +#' remaining nodes and updates their children and parent node references +#' accordingly. +#' +#' @noRd dropNodes <- function(treeeList){ finalNodeIdx <- sort(getTerminalNodes(treeeList = treeeList, currentIdx = 1, keepNonTerminal = TRUE)) treeeList <- treeeList[finalNodeIdx] - class(treeeList) <- "SingleTreee" + class(treeeList) <- "Treee" for(i in seq_along(treeeList)){ treeeList[[i]]$currentIndex <- i # re-assign the currentIndex if(!is.null(treeeList[[i]]$children)) { treeeList[[i]]$children <- sapply(treeeList[[i]]$children, function(x) which(finalNodeIdx == x)) }else{ - if(treeeList[[i]]$stopFlag == 0) treeeList[[i]]$stopFlag = 3 # due to pruning + if(treeeList[[i]]$stopInfo == "Normal") treeeList[[i]]$stopInfo = "Pruned" } treeeList[[i]]$parent <- which(finalNodeIdx == treeeList[[i]]$parent) } @@ -184,6 +202,13 @@ dropNodes <- function(treeeList){ } +#' Retrieve Terminal Nodes in a Decision Tree +#' +#' This function retrieves all terminal (leaf) nodes that are descendants of a +#' specified node in the decision tree. Optionally, intermediate nodes can also +#' be included in the result. 
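+#'
+#' For example, in a hypothetical three-node tree whose root (node 1) has two
+#' terminal children (nodes 2 and 3), the expected results would be:
+#'
+#' ```
+#' getTerminalNodes(currentIdx = 1, treeeList = treeeList) # c(2, 3)
+#' getTerminalNodes(currentIdx = 1, treeeList = treeeList, keepNonTerminal = TRUE) # c(1, 2, 3)
+#' ```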
+#' +#' @noRd getTerminalNodes <- function(currentIdx, treeeList, keepNonTerminal = FALSE){ #> Get all terminal nodes that are offsprings from currentIdx treeeNode <- treeeList[[currentIdx]] @@ -191,7 +216,7 @@ getTerminalNodes <- function(currentIdx, treeeList, keepNonTerminal = FALSE){ return(currentIdx) }else{ nonTerminalNodes <- if(keepNonTerminal) currentIdx - terminalNodes <- do.call(c, sapply(treeeNode$children,function(x) getTerminalNodes(currentIdx = x, treeeList = treeeList, keepNonTerminal = keepNonTerminal), simplify = FALSE)) + terminalNodes <- do.call(c, lapply(treeeNode$children,function(x) getTerminalNodes(currentIdx = x, treeeList = treeeList, keepNonTerminal = keepNonTerminal))) return(c(nonTerminalNodes, terminalNodes)) } } diff --git a/R/split.R b/R/split.R index 8fac2f0..42eb9a7 100644 --- a/R/split.R +++ b/R/split.R @@ -1,30 +1,27 @@ -# mixed ------------------------------------------------------------------- - - -getSplitFunLDA <- function(datX, response, modelLDA){ - if(modelLDA$predGini <= 0.1) modelLDA$prior[] <- 1 / length(modelLDA$prior) # change to equal prior - return(getSplitFunLDAhelper(datX = datX, - response = response, - modelLDA = modelLDA)) +getSplitFunLDA <- function(datX, modelULDA){ + if(modelULDA$predGini <= 0.1) modelULDA$prior[] <- 1 / length(modelULDA$prior) # change to equal prior + return(getSplitFunLDAhelper(datX = datX, modelULDA = modelULDA)) } -# LDA -------------------------------------------------------------------- - - -getSplitFunLDAhelper <- function(datX, response, modelLDA){ - #> This function is called only when building the tree - - #> If there are some classes not being predicted - #> we will assign them to the class with the second largest posterior prob - predictedOutcome <- predict(modelLDA, datX) - # if(length(unique(predictedOutcome)) == 1) return(NULL) # This will never happens, delete before next release - idxPred <- which(names(modelLDA$prior) %in% predictedOutcome) - splitRes <- lapply(idxPred, function(i) which(names(modelLDA$prior)[i] == predictedOutcome)) - res <- function(datX, missingReference){ - fixedData <- getDataInShape(data = datX, missingReference = missingReference) - predictedProb <- predict(modelLDA, fixedData,type = "prob")[,idxPred, drop = FALSE] - predictedOutcome <- max.col(predictedProb, ties.method = "first") +#' Helper Function for LDA-based Splitting in Tree Construction +#' +#' This function generates a splitting function based on a fitted ULDA model. It +#' assigns observations to the class with the minimal classification cost, and +#' returns the corresponding split results. 
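+#'
+#' The returned function assigns each observation to the class with the
+#' smallest expected misclassification cost, as in this rough sketch with
+#' hypothetical posteriors `P` (one row per observation) and cost matrix `C`
+#' (`C[i, j]` is the cost of predicting class `i` when the truth is class `j`):
+#'
+#' ```
+#' P <- matrix(c(0.7, 0.3,
+#'               0.4, 0.6), nrow = 2, byrow = TRUE)
+#' C <- matrix(c(0, 5,
+#'               1, 0), nrow = 2, byrow = TRUE)
+#' expectedCost <- P %*% t(C) # entry [n, k]: expected cost of predicting class k
+#' max.col(-expectedCost, ties.method = "first") # both observations pick class 2
+#' ```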
+#' +#' @noRd +getSplitFunLDAhelper <- function(datX, modelULDA){ + predictedOutcome <- predict(modelULDA, datX) + #> This will never happens, delete before next release + #> unless the Gini trick is abandoned + # if(length(unique(predictedOutcome)) == 1) return(NULL) + idxPred <- which(names(modelULDA$prior) %in% predictedOutcome) # in case some classes are not predicted + splitRes <- lapply(idxPred, function(i) which(names(modelULDA$prior)[i] == predictedOutcome)) + + res <- function(datX){ + predictedProb <- predict(modelULDA, datX, type = "prob")[, idxPred, drop = FALSE] + predictedOutcome <- max.col(-predictedProb %*% t(modelULDA$misClassCost[idxPred, idxPred, drop = FALSE]), ties.method = "first") return(lapply(seq_along(idxPred), function(i) which(i == predictedOutcome))) } diff --git a/R/utils.R b/R/utils.R index 6b3f107..222a7b4 100644 --- a/R/utils.R +++ b/R/utils.R @@ -1,458 +1,55 @@ - -# Check for input prior and misClassCost ---------------------------------- - -checkPriorAndMisClassCost <- function(prior, misClassCost, response, internal = FALSE){ - #> Modified from randomForest.default Line 114 - #> if internal == TRUE, the prior is actually prior / Nj - #> And it should have a name attributes. - - freqObs <- table(response, dnn = NULL) / length(response) # Default: Estimated Prior / - - #> prior fix - if (is.null(prior)) { - prior <- freqObs - } else { - if (length(prior) != nlevels(response)) - stop("length of prior not equal to number of classes") - if (!is.null(names(prior))){ - prior <- prior[findTargetIndex(names(prior), levels(response))] - } - if (any(prior < 0)) stop("prior must be non-negative") - } - - #> misClassCost fix - if (!is.null(misClassCost)) { # change the prior - if (dim(misClassCost)[1] != dim(misClassCost)[2] | dim(misClassCost)[1] != nlevels(response)) - stop("misclassification costs matrix has wrong dimension") - if(!all.equal(colnames(misClassCost), rownames(misClassCost))) - stop("misClassCost: colnames should be the same as rownames") - if (!is.null(colnames(misClassCost))){ - misClassCost <- misClassCost[findTargetIndex(colnames(misClassCost), levels(response)), - findTargetIndex(colnames(misClassCost), levels(response))] - } - prior <- prior * apply(misClassCost, 2, sum) - } - if(internal) prior <- prior / freqObs # - if(is.null(names(prior))) names(prior) <- levels(response) - return(prior / sum(prior)) -} - - -findTargetIndex <- function(nameObj, nameTarget){ - #> Assume nameObj and nameTarget are of the same length - #> check if nameObj are all in nameTarget - #> If yes, return the corresponding index - #> so that nameObj[idx] == nameTarget - targetIndex <- match(nameTarget, nameObj) - if (anyNA(targetIndex)) { - stop("The names do not match with the response") - } - return(targetIndex) -} - -getFinalPrior <- function(prior, response){ - priorObs <- table(response, dnn = NULL) / length(response) - levelLeftIdx <- match(names(priorObs), names(prior)) - stopifnot(!anyNA(levelLeftIdx)) # all levels should be in the prior - - prior <- prior[levelLeftIdx] * priorObs - return(prior / sum(prior)) -} - -sampleForLDA <- function(response, prior, K = 1000){ - idxFinal <- seq_along(response) - obsFreq <- table(response, dnn = NULL) - idxSubGroup <- which(obsFreq > K) - for(i in idxSubGroup){ - idxDelete <- sample(which(response == names(obsFreq)[i]), obsFreq[i] - K) - idxFinal <- setdiff(idxFinal, idxDelete) - } - - levelLeftIdx <- match(names(obsFreq), names(prior)) - prior <- prior[levelLeftIdx] * table(response[idxFinal], dnn = NULL) / obsFreq - 
return(list(idxFinal = idxFinal, prior = prior)) -} - -# Missing Value Imputation ------------------------------------------------ - - -missingFix <- function(data, missingMethod = c("medianFlag", "newLevel")){ - - #> data: a data.frame - #> missingMethod: for numerical / categorical variables, respectively - - misMethod <- misMethodHelper(missingMethod = missingMethod) - data <- createFlagColumns(data = data, misMethod = misMethod) # create flag columns - - numOrNot <- getNumFlag(data) # num or cat - NAcolumns <- sapply(data, anyNA) - dataNRef <- rbind(data, NA) # add ref to the last row of data, initialize using NAs - - for(i in seq_len(ncol(dataNRef))){ - - if(numOrNot[i]){ - # numerical / logical vars - #> The function below output NaN when all entries are NA, - #> so we add an if to prevent a vector of only NA and NaN (not constant) - targetValue <- do.call(misMethod$numMethod, list(dataNRef[,i], na.rm = TRUE)) - missingOrNot <- is.na(dataNRef[,i]) - if(!all(missingOrNot)) dataNRef[which(missingOrNot), i] <- targetValue - }else{ - # categorical vars - if(NAcolumns[i]){ # any NA - dataNRef[,i] <- as.character(dataNRef[,i]) # for new level addition - dataNRef[is.na(dataNRef[,i]),i] <- ifelse(misMethod$catMethod == "newLevel", "new0_0Level", as.character(getMode(dataNRef[,i]))) - dataNRef[,i] <- as.factor(dataNRef[,i]) # level information will be used in prediction - }else{ - dataNRef[,i] <- as.factor(dataNRef[,i]) - dataNRef[nrow(dataNRef),i] <- factor(getMode(dataNRef[,i]), levels = levels(dataNRef[,i])) - } - } - } - - checkColwithAllNA <- sapply(dataNRef, anyNA) # remove columns with all NAs - if(any(checkColwithAllNA)) dataNRef <- dataNRef[, !checkColwithAllNA] - - return(list(data = dataNRef[-nrow(dataNRef),,drop = FALSE], - ref = dataNRef[nrow(dataNRef),,drop = FALSE])) -} - -createFlagColumns <- function(data, misMethod){ - #> given a data and numOrNot (from getNumFlag) - #> output a data with added flag columns with correct 0/1 - #> We only add _FLAG to vars where NAs exist, not all columns - #> since even we add NAflags, they will not be trained - #> Notes: num flags are before cat flags, - #> regardless of their relative column positions - - numOrNot <- getNumFlag(data) # num or cat - NAcolumns <- sapply(data, anyNA) - - if(misMethod$numFlagOrNot & sum(numOrNot) > 0){ - NAcol <- which(numOrNot & NAcolumns) - if(length(NAcol) > 0){ - #> The line below will treat missing flags as numerical variables - #> as.factor() can be applied if we want them to be factors - numFlagCols <- do.call(cbind, sapply(NAcol, function(colIdx) is.na(data[, colIdx])+0, simplify = FALSE)) - colnames(numFlagCols) <- paste(colnames(data)[NAcol], "FLAG", sep = "_") - data <- cbind(data, numFlagCols) - } - } - - if(misMethod$catFlagOrNot & sum(!numOrNot) > 0){ - NAcol <- which((!numOrNot) & NAcolumns) - if(length(NAcol) > 0){ - catFlagCols <- do.call(cbind, sapply(NAcol, function(colIdx) is.na(data[, colIdx])+0, simplify = FALSE)) - colnames(catFlagCols) <- paste(colnames(data)[NAcol], "FLAG", sep = "_") - data <- cbind(data, catFlagCols) - } - } - return(data) -} - -misMethodHelper <- function(missingMethod){ - #> Aim: classify the missing imputation methods based on their methods and flags - numMethod <- missingMethod[1]; catMethod <- missingMethod[2] - numFlagOrNot <- grepl("Flag", numMethod) - catFlagOrNot <- grepl("Flag", catMethod) - numMethod <- ifelse(grepl("mean", numMethod), "mean", "median") - catMethod <- ifelse(grepl("mode", catMethod), "mode", "newLevel") - return(list(numMethod = numMethod, 
- catMethod = catMethod, - numFlagOrNot = numFlagOrNot, - catFlagOrNot = catFlagOrNot)) -} - - -# constant check --------------------------------------------------------- - -constantColCheck <- function(data, idx, tol = 1e-8, naAction = "keep"){ - if(missing(idx)) idx <- seq_len(ncol(data)) # default output columns - #> constant columns fix: the data in this step should not contains NA - - constantColCheckHelper <- function(x, tol = 1e-8){ - if(getNumFlag(x)) x <- round(x, digits = -log(tol,base = 10)) - if(naAction != "keep") return(length(unique(na.omit(x))) > 1) - return(length(unique(x)) > 1) - } - - idxNotConst <- which(sapply(data, function(x) constantColCheckHelper(x, tol))) - - return(idx[idxNotConst]) -} - - -# Get column Types -------------------------------------------------------- - -getNumFlag <- function(x, index = FALSE){ - #> index decides return type. e.g. return c(1,0,0,1,1,0) (FALSE) or c(1,4,5) (YES) - #> logical has to be included, a column with all NAs has to be viewed as numeric - if(is.null(dim(x))){return(any(class(x) %in% c('numeric', 'integer', 'logical')))} - - stopifnot(is.data.frame(x)) - - numOrNot <- sapply(x, class) %in% c('numeric', 'integer', 'logical') - - if(index){numOrNot <- which(numOrNot)} - - return(numOrNot) -} - -# Get mode -------------------------------------------------------------------- - -getMode <- function(v, prior, posterior = FALSE){ - #> posterior: return posterior probs (TRUE) or mode (FALSE) - #> NA will be ignored - v <- as.factor(v) - if(missing(prior)){ - prior = rep(1,nlevels(v)) # equal prior - }else{ - if (is.null(names(prior))){ - stopifnot(length(prior) == nlevels(v)) - names(prior) <- levels(v) - } else prior <- prior[match(levels(v), names(prior))] - } - - summary_table <- table(v) * prior - if(length(summary_table) == 0) return(NA) - if(posterior){return(summary_table / sum(summary_table))} - return(names(which.max(summary_table))) -} - - - -# Stop check -------------------------------------------------------------- - +#' Check if tree construction should stop +#' +#' @noRd +#' +#' @param responseCurrent A vector of the current response values at the node. +#' @param numCol The number of covariate columns remaining. +#' @param maxTreeLevel The maximum allowed level of the tree. +#' @param minNodeSize The minimum number of observations required at a node. +#' @param currentLevel The current level of the tree being constructed. 
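+#'
+#' @return A character string: `"Normal"` if the node may be split further,
+#'   otherwise `"Insufficient data"` or `"Maximum level reached"`, indicating
+#'   why the node stops growing.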
stopCheck <- function(responseCurrent, numCol, maxTreeLevel, minNodeSize, currentLevel){ - # 0: Normal - # 1: Stop and return posterior majority - # 2: stop and fit LDA - - flagNodeSize <- length(responseCurrent) <= minNodeSize # 数据量不够了,LDA is unstable - flagTreeLevel <- currentLevel >= maxTreeLevel # 层数到了 + flagNodeSize <- length(responseCurrent) <= minNodeSize # Data size too small + flagTreeLevel <- currentLevel >= maxTreeLevel flagCol <- numCol == 0 # no covs left - flagResponse <- length(unique(responseCurrent)) == 1 # 只有一种y - - if (flagResponse | flagCol | flagNodeSize) {return(1)} - if (flagTreeLevel) {return(2)} - return(0) -} - - - -# Get LD scores ----------------------------------------------------------- - -getDesignMatrix <- function(modelLDA, data, scale = FALSE){ - # Output: the design matrix - Terms <- delete.response(modelLDA$terms) - modelX <- model.matrix(Terms, data = data, xlev = modelLDA$xlevels) - - if(scale){ # reserved for the scaling in getLDscores - modelX <- sweep(modelX[,modelLDA$varIdx,drop = FALSE], 2, modelLDA$varCenter, "-") - modelX <- sweep(modelX, 2, modelLDA$varSD, "/") - } - return(modelX) -} - -getLDscores <- function(modelLDA, data, nScores = -1){ - if(anyNA(data)) data <- getDataInShape(data = data, missingReference = modelLDA$misReference) - modelX <- getDesignMatrix(modelLDA = modelLDA, data = data, scale = TRUE) - if(nScores > 0) modelLDA$scaling <- modelLDA$scaling[, seq_len(nScores), drop = FALSE] - LDscores <- modelX %*% modelLDA$scaling - - return(LDscores) -} - - -# New level fix + Missing ----------------------------------------------------------- - -getDataInShape <- function(data, missingReference){ - #> change the shape of test data to the training data - #> and make sure that the dimension of the data is the same as missingRefernce - - cname <- colnames(missingReference) - nameVarIdx <- match(cname, colnames(data)) - if(anyNA(nameVarIdx)){ - #> New columns fix (or Flags): If there are less columns than it should be, - #> add columns with NA - data[,cname[which(is.na(nameVarIdx))]] <- NA - nameVarIdx <- match(cname, colnames(data)) - } - - for(currentIdx in seq_len(ncol(missingReference))){ # Main program starts - #> The tricky part is the iterator is based on the missingReference, NOT the data. 
- newIdx <- nameVarIdx[currentIdx] - numOrNot <- getNumFlag(missingReference[, currentIdx]) - - ### New-level Fix for Categorical Variable ### - if(!numOrNot) data[, newIdx] <- factor(data[, newIdx], levels = levels(missingReference[, currentIdx])) # may generate NAs - - missingOrNot <- is.na(data[, newIdx]) - if(!any(missingOrNot)) next - - ### Flag Variable Detection ### - #> This part is actually more complicated than expected - #> Four combinations could happen: Ori / Flag both can have NA or complete - currentVarName <- cname[currentIdx] - - ## Scenario 1: It has a related flag variable in the data ## - #> Only modify those flags where the original variable is missing - #> Keep other parts still, since there could already be imputed values - #> in the original variable that have been taken care of - currentFlagIdx <- which(cname == paste(currentVarName,"FLAG",sep = "_")) - if(length(currentFlagIdx) == 1) data[which(missingOrNot), nameVarIdx[currentFlagIdx]] <- 1 - - ## Scenario 2: It is a flag and it has an original variable in (or not in) the data ## - #> Only impute those NAs in the flag, but keep the values that are already in the flag - if(grepl("_FLAG$", currentVarName)){ - orginalVarName <- sub("_FLAG$", "", currentVarName) - orginalVarIdx <- which(cname == orginalVarName) - if(length(orginalVarIdx) == 1){ - data[which(missingOrNot), newIdx] <- is.na(data[which(missingOrNot), nameVarIdx[orginalVarIdx]]) + 0 - } else data[, newIdx] <- 1 # The original data is NOT found - next - } - - ### For numerical & categorical variables ### - data[which(missingOrNot), newIdx] <- missingReference[1, currentIdx] - } - - return(data[,nameVarIdx, drop = FALSE]) -} - - - - -# Prediction in terminal Nodes -------------------------------------------- - -predNode <- function(data, treeeNode, missingReference, type){ - #> data is a data.frame - if(treeeNode$nodeModel == "LDA"){ - data <- getDataInShape(data = data, missingReference = missingReference) - return(predict(object = treeeNode$nodePredict, newdata = data, type = type)) - } else{ - if(type == "response"){ - return(rep(treeeNode$nodePredict, nrow(data))) - } else{ # if type = "all", the extra response column will be added later - pred <- matrix(0,nrow = nrow(data), ncol = length(treeeNode$proportions), dimnames = list(c(), names(treeeNode$proportions))) - pred[,which(treeeNode$nodePredict == colnames(pred))] <- 1 - return(pred) - } + flagResponse <- length(unique(responseCurrent)) == 1 # only one class left + + if (flagResponse | flagCol | flagNodeSize) {return("Insufficient data")} + if (flagTreeLevel) {return("Maximum level reached")} + return("Normal") +} + + +#' Update Prior and Misclassification Cost +#' +#' This function updates the class prior probabilities and misclassification +#' cost matrix based on the observed response distribution. It adjusts the prior +#' and misclassification costs either inside or outside a node, depending on the +#' `insideNode` parameter. 
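+#'
+#' As a rough numeric sketch (hypothetical values): with observed class
+#' frequencies `(0.8, 0.2)` and a requested prior of `(0.5, 0.5)`, the
+#' relative prior computed outside a node is proportional to
+#' `(0.5 / 0.8, 0.5 / 0.2) = (0.625, 2.5)`. Inside a node, the relative prior
+#' is multiplied back by the class frequencies observed at that node and
+#' re-normalized to sum to one.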
+#' +#' @noRd +updatePriorAndMisClassCost <- function(prior, misClassCost, response, insideNode = TRUE){ + if(!insideNode){ # Calculate the relative prior + res <- checkPriorAndMisClassCost(prior = prior, misClassCost = misClassCost, response = response) + priorObs <- as.vector(table(response, dnn = NULL)) / length(response) + res$prior <- res$prior / priorObs + }else{ + priorObs <- table(response, dnn = NULL) / length(response) + levelLeftIdx <- match(names(priorObs), names(prior)) + prior <- prior[levelLeftIdx] * priorObs + res <- list(prior = prior / sum(prior), + misClassCost = misClassCost[levelLeftIdx, levelLeftIdx, drop = FALSE]) } + return(res) } - -# Get the p-value for testing the current nodes' performance -------------- - getOneSidedPvalue <- function(N, lossBefore, lossAfter){ + #> Get the p-value for testing the current split's performance #> H1: lossBefore > lossAfter. loss stands for the prediction error zStat <- (lossBefore - lossAfter) / sqrt((lossBefore * (N - lossBefore) + lossAfter * (N - lossAfter)) / N + 1e-16) - pnorm(zStat, lower.tail = FALSE) -} - - -# Variable Selection ------------------------------------------------------ - -getChiSqStat <- function(datX, y){ - sapply(datX, function(x) getChiSqStatHelper(x, y)) + stats::pnorm(zStat, lower.tail = FALSE) } -getChiSqStatHelper <- function(x,y){ - if(getNumFlag(x)){ # numerical variable: first change to factor - m = mean(x,na.rm = T); s = sd(x,na.rm = T) - if(sum(!is.na(x)) >= 30 * nlevels(y)){ - splitNow = c(m - s *sqrt(3)/2, m, m + s *sqrt(3)/2) - }else splitNow = c(m - s *sqrt(3)/3, m + s *sqrt(3)/3) - - if(length(unique(splitNow)) == 1) return(0) # No possible split - x = cut(x, breaks = c(-Inf, splitNow, Inf), right = TRUE) - } - - if(anyNA(x)){ - levels(x) = c(levels(x), 'newLevel') - x[is.na(x)] <- 'newLevel' - } - if(length(unique(x)) == 1) return(0) # No possible split - - fit <- suppressWarnings(chisq.test(x, y)) +checkPriorAndMisClassCost <- utils::getFromNamespace("checkPriorAndMisClassCost", "folda") +getMode <- utils::getFromNamespace("getMode", "folda") - #> Change to 1-df wilson_hilferty chi-squared stat unless - #> the original df = 1 and p-value is larger than 10^(-16) - ans = unname(ifelse(fit$parameter > 1L, ifelse(fit$p.value > 10^(-16), - qchisq(1-fit$p.value, df = 1), - wilson_hilferty(fit$statistic,fit$parameter)), fit$statistic)) - return(ans) -} - - -wilson_hilferty = function(chi, df){ # change df = K to df = 1 - ans = max(0, (7/9 + sqrt(df) * ( (chi / df) ^ (1/3) - 1 + 2 / (9 * df) ))^3) - return(ans) -} - - -# Get tree depth ---------------------------------------------------------- - -getDepth <- function(treee){ - if(class(treee) == "Treee") treee = treee$treee - depthAll <- numeric(length(treee)) - updateList <- seq_along(treee)[-1] - if(length(updateList) != 0){ - for(i in updateList){ - currentNode <- treee[[i]] - depthAll[i] <- depthAll[currentNode$parent] + 1 - } - } - return(depthAll) -} - - -# RcppEigen --------------------------------------------------------------- - -# library(Rcpp) -# library(RcppEigen) -# Rcpp::cppFunction(' -# Rcpp::List qrEigen(const Eigen::MatrixXd &A) { -# Eigen::HouseholderQR qr(A); -# Eigen::MatrixXd Q = qr.householderQ() * Eigen::MatrixXd::Identity(A.rows(), A.cols()); -# Eigen::MatrixXd R = qr.matrixQR().topRows(A.cols()).template triangularView(); -# return Rcpp::List::create(Rcpp::Named("Q") = Q, Rcpp::Named("R") = R); -# } -# ', depends = "RcppEigen") -# -# Rcpp::cppFunction(' -# Rcpp::List svdEigen(const Eigen::MatrixXd &A) { -# Eigen::BDCSVD 
svd(A, Eigen::ComputeThinU | Eigen::ComputeThinV);
-#     Eigen::MatrixXd U = svd.matrixU();
-#     Eigen::VectorXd S = svd.singularValues();
-#     Eigen::MatrixXd V = svd.matrixV();
-#     return Rcpp::List::create(Rcpp::Named("u") = U, Rcpp::Named("d") = S, Rcpp::Named("v") = V);
-# }
-# ', depends = "RcppEigen")
-
-saferSVD <- function(x, ...){
-  #> Target for error code 1 from Lapack routine 'dgesdd' non-convergence error
-  #> Current solution: Round the design matrix to make approximations,
-  #> hopefully this will solve the problem
-  #>
-  #> The code is a little lengthy, since the variable assignment in tryCatch is tricky
-  parList <- list(svdObject = NULL,
-                  svdSuccess = FALSE,
-                  errorDigits = 16,
-                  x = x)
-  while (!parList$svdSuccess) {
-    parList <- tryCatch({
-      parList$svdObject <- svd(parList$x, ...)
-      parList$svdSuccess <- TRUE
-      parList
-    }, error = function(e) {
-      if (grepl("error code 1 from Lapack routine 'dgesdd'", e$message)) {
-        parList$x <- round(x, digits = parList$errorDigits)
-        parList$errorDigits <- parList$errorDigits - 1
-        return(parList)
-      } else stop(e)
-    })
-  }
-  return(parList$svdObject)
-}
diff --git a/README.Rmd b/README.Rmd
index 9b6fcc0..cc86be3 100644
--- a/README.Rmd
+++ b/README.Rmd
@@ -18,6 +18,7 @@ knitr::opts_chunk$set(
 [![CRAN status](https://www.r-pkg.org/badges/version/LDATree)](https://CRAN.R-project.org/package=LDATree)
 [![R-CMD-check](https://github.com/Moran79/LDATree/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/Moran79/LDATree/actions/workflows/R-CMD-check.yaml)
+![CRAN Downloads](https://cranlogs.r-pkg.org/badges/grand-total/LDATree)
 
 `LDATree` is an R modeling package for fitting classification trees. If you are unfamiliar with classification trees, here is a [tutorial](http://www.sthda.com/english/articles/35-statistical-machine-learning-essentials/141-cart-model-decision-tree-essentials/) about the traditional CART and its R implementation `rpart`.
 
@@ -40,7 +41,7 @@ Compared to other similar trees, `LDATree` sets itself apart in the following wa
 install.packages("LDATree")
 ```
 
-The CRAN version is an outdated one from 08/2023. As of 06/2024, please use the command below for the current version, and the official updated CRAN release will be coming soon!
+The CRAN version is an outdated one from 08/2023. Please stay tuned for the latest version, which will be released around 10/2024. Meanwhile, feel free to try the development version below.
 
```{r,fig.asp=0.618,out.width = "80%",fig.align = "center", eval=FALSE}
 library(devtools)
@@ -54,9 +55,9 @@ To build an LDATree:
 
 ```{r,fig.asp=0.618,out.width = "100%",fig.align = "center"}
 library(LDATree)
 set.seed(443)
-mpg <- as.data.frame(ggplot2::mpg)
-datX <- mpg[, -5] # All predictors without Y
-response <- mpg[, 5] # we try to predict "cyl" (number of cylinders)
+diamonds <- as.data.frame(ggplot2::diamonds)[sample(53940, 2000),]
+datX <- diamonds[, -2]
+response <- diamonds[, 2] # we try to predict "cut"
 fit <- Treee(datX = datX, response = response, verbose = FALSE)
 ```
diff --git a/README.md b/README.md
index 395514d..a0e9cab 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,7 @@
 [![CRAN status](https://www.r-pkg.org/badges/version/LDATree)](https://CRAN.R-project.org/package=LDATree)
 [![R-CMD-check](https://github.com/Moran79/LDATree/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/Moran79/LDATree/actions/workflows/R-CMD-check.yaml)
+![CRAN Downloads](https://cranlogs.r-pkg.org/badges/grand-total/LDATree)
 
 
 `LDATree` is an R modeling package for fitting classification trees. If
@@ -41,9 +42,9 @@ following ways:
 install.packages("LDATree")
 ```
 
-The CRAN version is an outdated one from 08/2023. As of 06/2024, please
-use the command below for the current version, and the official updated
-CRAN release will be coming soon!
+The CRAN version is an outdated one from 08/2023. Please stay tuned for
+the latest version, which will be released around 10/2024. Meanwhile,
+feel free to try the development version below.
 
 ``` r
 library(devtools)
@@ -57,9 +58,9 @@ To build an LDATree:
 
 ``` r
 library(LDATree)
 set.seed(443)
-mpg <- as.data.frame(ggplot2::mpg)
-datX <- mpg[, -5] # All predictors without Y
-response <- mpg[, 5] # we try to predict "cyl" (number of cylinders)
+diamonds <- as.data.frame(ggplot2::diamonds)[sample(53940, 2000),]
+datX <- diamonds[, -2]
+response <- diamonds[, 2] # we try to predict "cut"
 fit <- Treee(datX = datX, response = response, verbose = FALSE)
 ```
 
@@ -92,7 +93,7 @@ plot(fit, datX = datX, response = response, node = 3)
 
 # 3. A message
 plot(fit, datX = datX, response = response, node = 2)
-#> [1] "Every observation in this node is predicted to be 4"
+#> [1] "Every observation in node 2 is predicted to be Fair"
 ```
 
 To make predictions:
 
 ``` r
 # Prediction only.
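+# (equivalent to predict(fit, datX, type = "response"), the default)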
predictions <- predict(fit, datX) head(predictions) -#> [1] "4" "4" "4" "4" "6" "6" +#> [1] "Ideal" "Ideal" "Ideal" "Ideal" "Ideal" "Ideal" ``` ``` r # A more informative prediction predictions <- predict(fit, datX, type = "all") head(predictions) -#> response node 4 5 6 8 -#> 1 4 14 1 0 0 0 -#> 2 4 6 1 0 0 0 -#> 3 4 6 1 0 0 0 -#> 4 4 6 1 0 0 0 -#> 5 6 18 0 0 1 0 -#> 6 6 15 0 0 1 0 +#> response node Fair Good Very Good Premium Ideal +#> 1 Ideal 6 4.362048e-03 0.062196349 0.2601145 0.056664046 0.6166630 +#> 2 Ideal 6 1.082022e-04 0.006308281 0.1290079 0.079961227 0.7846144 +#> 3 Ideal 6 7.226446e-03 0.077434549 0.2036148 0.023888946 0.6878352 +#> 4 Ideal 6 1.695119e-02 0.115233616 0.1551836 0.008302145 0.7043295 +#> 5 Ideal 6 4.923729e-05 0.004157352 0.1498265 0.187391975 0.6585749 +#> 6 Ideal 6 4.827312e-03 0.061274797 0.1978061 0.027410359 0.7086815 ``` ## Getting help diff --git a/man/Treee.Rd b/man/Treee.Rd index 6532cbe..fa020d7 100644 --- a/man/Treee.Rd +++ b/man/Treee.Rd @@ -2,99 +2,122 @@ % Please edit documentation in R/Treee.R \name{Treee} \alias{Treee} -\title{Classification trees with Linear Discriminant Analysis terminal nodes} +\title{Classification Trees with Uncorrelated Linear Discriminant Analysis Terminal +Nodes} \usage{ Treee( datX, response, - ldaType = c("step", "all"), - nodeModel = c("LDA", "mode"), - missingMethod = c("medianFlag", "newLevel"), - prior = NULL, - misClassCost = NULL, - pruneMethod = c("post", "pre", "pre-post"), - numberOfPruning = 10, + ldaType = c("forward", "all"), + nodeModel = c("ULDA", "mode"), + pruneMethod = c("pre", "post"), + numberOfPruning = 10L, maxTreeLevel = 20L, minNodeSize = NULL, - pThreshold = 0.1, - verbose = TRUE, - kSample = 1e+07 + pThreshold = NULL, + prior = NULL, + misClassCost = NULL, + missingMethod = c("medianFlag", "newLevel"), + kSample = -1, + verbose = TRUE ) } \arguments{ -\item{missingMethod}{Missing value solutions for numerical variables and -factor variables. \code{'mean'}, \code{'median'}, \code{'meanFlag'}, \code{'medianFlag'} are -available for numerical variables. \code{'mode'}, \code{'modeFlag'}, \code{'newLevel'} are -available for factor variables. The word \code{'Flag'} in the methods indicates -whether a missing flag is added or not. The \code{'newLevel'} method means that -all missing values are replaced with a new level rather than imputing them -to another existing value.} +\item{datX}{A data frame of predictor variables.} + +\item{response}{A vector of response values corresponding to \code{datX}.} + +\item{ldaType}{A character string specifying the type of LDA to use. Options +are \code{"forward"} for forward ULDA or \code{"all"} for full ULDA. Default is +\code{"forward"}.} + +\item{nodeModel}{A character string specifying the type of model used in each +node. Options are \code{"ULDA"} for Uncorrelated LDA, or \code{"mode"} for predicting +based on the most frequent class. Default is \code{"ULDA"}.} + +\item{pruneMethod}{A character string specifying the pruning method. \code{"pre"} +performs pre-pruning based on p-value thresholds, and \code{"post"} performs +cross-validation-based post-pruning. Default is \code{"pre"}.} -\item{maxTreeLevel}{controls the largest tree size possible for either a -direct-stopping tree or a CV-pruned tree. Adding one extra level (depth) -introduces an additional layer of nodes at the bottom of the current tree. 
-e.g., when the maximum level is 1 (or 2), the maximum tree size is 3 (or -7).} +\item{numberOfPruning}{An integer specifying the number of folds for +cross-validation during post-pruning. Default is \code{10}.} -\item{minNodeSize}{controls the minimum node size. Think carefully before -changing this value. Setting a large number might result in early stopping -and reduced accuracy. By default, it's set to one plus the number of -classes in the response variable.} +\item{maxTreeLevel}{An integer controlling the maximum depth of the tree. +Increasing this value allows for deeper trees with more nodes. Default is +\code{20}.} -\item{verbose}{a logical. If TRUE, the function provides additional -diagnostic messages or detailed output about its progress or internal -workings. Default is FALSE, where the function runs silently without -additional output.} +\item{minNodeSize}{An integer controlling the minimum number of samples +required in a node. Setting a higher value may lead to earlier stopping and +smaller trees. If not specified, it defaults to one plus the number of +response classes.} -\item{formula}{an object of class \link{formula}, which has the form \code{class ~ x1 + x2 + ...}} +\item{pThreshold}{A numeric value used as a threshold for pre-pruning based +on p-values. Lower values result in more conservative trees. If not +specified, defaults to \code{0.01} for pre-pruning and \code{0.51} for post-pruning.} -\item{data}{a data frame that contains both predictors and the response. -Missing values are allowed in predictors but not in the response.} +\item{prior}{A numeric vector of prior probabilities for each class. If +\code{NULL}, the prior is automatically calculated from the data.} + +\item{misClassCost}{A square matrix \eqn{C}, where each element \eqn{C_{ij}} +represents the cost of classifying an observation into class \eqn{i} given +that it truly belongs to class \eqn{j}. If \code{NULL}, a default matrix with +equal misclassification costs for all class pairs is used. Default is +\code{NULL}.} + +\item{missingMethod}{A character string specifying how missing values should +be handled. Options include \code{'mean'}, \code{'median'}, \code{'meanFlag'}, +\code{'medianFlag'} for numerical variables, and \code{'mode'}, \code{'modeFlag'}, +\code{'newLevel'} for factor variables. \code{'Flag'} options indicate whether a +missing flag is added, while \code{'newLevel'} replaces missing values with a +new factor level.} + +\item{kSample}{An integer specifying the number of samples to use for +downsampling during tree construction. Set to \code{-1} to disable downsampling.} + +\item{verbose}{A logical value. If \code{TRUE}, progress messages and detailed +output are printed during tree construction and pruning. Default is +\code{FALSE}.} } \value{ -An object of class \code{Treee} containing the following components: +An object of class \code{Treee} containing the fitted tree, which is a +list of nodes, each an object of class \code{TreeeNode}. Each \code{TreeeNode} +contains: \itemize{ -\item \code{formula}: the formula passed to the \code{\link[=Treee]{Treee()}} -\item \code{treee}: a list of all the tree nodes, and each node is an object of class \code{TreeeNode}. 
-\item \code{missingMethod}: the missingMethod passed to the \code{\link[=Treee]{Treee()}} - -An object of class \code{TreeeNode} containing the following components: -\item \code{currentIndex}: the node index of the current node -\item \code{currentLevel}: the level of the current node in the tree -\item \code{idxRow}, \code{idxCol}: the row and column indices showing which portion of data is used in the current node -\item \code{currentLoss}: ? -\item \code{accuracy}: the training accuracy of the current node -\item \code{stopFlag}: ? -\item \code{proportions}: shows the observed frequency for each class -\item \code{parent}: the node index of its parent -\item \code{children}: the node indices of its direct children (not including its children's children) -\item \code{misReference}: a data frame, serves as the reference for missing value imputation -\item \code{splitFun}: ? -\item \code{nodeModel}: one of \code{'mode'} or \code{'LDA'}. It shows the type of predictive model fitted in the current node -\item \code{nodePredict}: the fitted predictive model in the current node. It is an object of class \code{ldaGSVD} if LDA is fitted. If \code{nodeModel = 'mode'}, then it is a vector of length one, showing the plurality class. +\item \code{currentIndex}: The node index in the tree. +\item \code{currentLevel}: The depth of the current node in the tree. +\item \code{idxRow}, \code{idxCol}: Row and column indices indicating which part of the original data was used for this node. +\item \code{currentLoss}: The training error for this node. +\item \code{accuracy}: The training accuracy for this node. +\item \code{stopInfo}: Information on why the node stopped growing. +\item \code{proportions}: The observed frequency of each class in this node. +\item \code{prior}: The (adjusted) class prior probabilities used for ULDA or mode prediction. +\item \code{misClassCost}: The misclassification cost matrix used in this node. +\item \code{parent}: The index of the parent node. +\item \code{children}: A vector of indices of this node’s direct children. +\item \code{splitFun}: The splitting function used for this node. +\item \code{nodeModel}: Indicates the model fitted at the node (\code{'ULDA'} or \code{'mode'}). +\item \code{nodePredict}: The fitted model at the node, either a ULDA object or the plurality class. +\item \code{alpha}: The p-value from a two-sample t-test used to evaluate the strength of the split. +\item \code{childrenTerminal}: A vector of indices representing the terminal nodes that are descendants of this node. +\item \code{childrenTerminalLoss}: The total training error accumulated from all nodes listed in \code{childrenTerminal}. } } \description{ -\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} Fit an LDATree model. -} -\details{ -Unlike other classification trees, LDATree integrates LDA throughout the -entire tree-growing process. Here is a breakdown of its distinctive features: -\itemize{ -\item The tree searches for the best binary split based on sample quantiles of the first linear discriminant score. -\item An LDA/GSVD model is fitted for each terminal node (For more details, refer to \code{\link[=ldaGSVD]{ldaGSVD()}}). -\item Missing values can be imputed using the mean, median, or mode, with optional missing flags available. -\item By default, the tree employs a direct-stopping rule. 
However, cross-validation using the alpha-pruning from CART is also provided.
-}
+This function fits a classification tree where each node has an Uncorrelated
+Linear Discriminant Analysis (ULDA) model. It can also handle missing values
+and perform downsampling. The resulting tree can be pruned either through
+pre-pruning or post-pruning methods.
 }
 \examples{
-fit <- Treee(Species~., data = iris)
+fit <- Treee(datX = iris[, -5], response = iris[, 5], verbose = FALSE)
 # Use cross-validation to prune the tree
-fitCV <- Treee(Species~., data = iris)
-# prediction
-predict(fit,iris)
-# plot the overall tree
-plot(fit)
-# plot a certain node
-plot(fit, iris, node = 1)
+fitCV <- Treee(datX = iris[, -5], response = iris[, 5], pruneMethod = "post", verbose = FALSE)
+head(predict(fit, iris)) # prediction
+plot(fit) # plot the overall tree
+plot(fit, datX = iris[, -5], response = iris[, 5], node = 1) # plot a certain node
+}
+\references{
+Wang, S. (2024). A New Forward Discriminant Analysis Framework
+Based On Pillai's Trace and ULDA. \emph{arXiv preprint arXiv:2409.03136}.
+Available at \url{https://arxiv.org/abs/2409.03136}.
 }
diff --git a/man/figures/README-plot1-1.png b/man/figures/README-plot1-1.png
index e38f4c0..8b1748e 100644
Binary files a/man/figures/README-plot1-1.png and b/man/figures/README-plot1-1.png differ
diff --git a/man/figures/README-plot2-1.png b/man/figures/README-plot2-1.png
index 731555d..a03a3dd 100644
Binary files a/man/figures/README-plot2-1.png and b/man/figures/README-plot2-1.png differ
diff --git a/man/figures/README-plot2-2.png b/man/figures/README-plot2-2.png
index 99edc66..d19b33a 100644
Binary files a/man/figures/README-plot2-2.png and b/man/figures/README-plot2-2.png differ
diff --git a/man/figures/lifecycle-archived.svg b/man/figures/lifecycle-archived.svg
deleted file mode 100644
index 745ab0c..0000000
diff --git a/man/figures/lifecycle-defunct.svg b/man/figures/lifecycle-defunct.svg
deleted file mode 100644
index d5c9559..0000000
diff --git a/man/figures/lifecycle-deprecated.svg b/man/figures/lifecycle-deprecated.svg
deleted file mode 100644
index b61c57c..0000000
diff --git a/man/figures/lifecycle-experimental.svg b/man/figures/lifecycle-experimental.svg
deleted file mode 100644
index 5d88fc2..0000000
diff --git a/man/figures/lifecycle-maturing.svg b/man/figures/lifecycle-maturing.svg
deleted file mode 100644
index 897370e..0000000
diff --git a/man/figures/lifecycle-questioning.svg b/man/figures/lifecycle-questioning.svg
deleted file mode 100644
index 7c1721d..0000000
diff --git a/man/figures/lifecycle-soft-deprecated.svg b/man/figures/lifecycle-soft-deprecated.svg
deleted file mode 100644
index 9c166ff..0000000
diff --git a/man/figures/lifecycle-stable.svg b/man/figures/lifecycle-stable.svg
deleted file mode 100644
index 9bf21e7..0000000
diff --git a/man/figures/lifecycle-superseded.svg b/man/figures/lifecycle-superseded.svg
deleted file mode 100644
index db8d757..0000000
[Nine lifecycle badge SVGs deleted; each contained only the badge label matching its filename, so the markup is omitted here.]
diff --git a/man/ldaGSVD.Rd b/man/ldaGSVD.Rd
deleted file mode 100644
index 104376c..0000000
--- a/man/ldaGSVD.Rd
+++ /dev/null
@@ -1,62 +0,0 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/ldaGSVD.R
-\name{ldaGSVD}
-\alias{ldaGSVD}
-\title{Linear Discriminant Analysis using the Generalized Singular Value
-Decomposition}
-\usage{
-ldaGSVD(
-  datX,
-  response,
-  method = c("all", "step"),
-  fixNA = TRUE,
-  missingMethod = c("medianFlag", "newLevel"),
-  prior = NULL,
-  misClassCost = NULL,
-  insideTree = FALSE
-)
-}
-\arguments{
-\item{method}{default to be all}
-
-\item{data}{a data frame that contains both predictors and the response.
-Missing values are NOT allowed.}
-}
-\value{
-An object of class \code{ldaGSVD} containing the following components:
-\itemize{
-\item \code{scaling}: a matrix which transforms the training data to LD scores, normalized so that the within-group scatter matrix is proportional to the identity matrix.
-\item \code{formula}: the formula passed to the \code{\link[=ldaGSVD]{ldaGSVD()}}
-\item \code{terms}: a object of class \code{terms} derived using the input \code{formula} and the training data
-\item \code{prior}: a \code{table} of the estimated prior probabilities.
-\item \code{groupMeans}: a matrix that records the group means of the training data on the transformed LD scores.
-\item \code{xlevels}: a list records the levels of the factor predictors, derived using the input \code{formula} and the training data
-}
-}
-\description{
-\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} Fit an LDA/GSVD model.
-}
-\details{
-Traditional Fisher's Linear Discriminant Analysis (LDA) ceases to work when
-the within-class scatter matrix is singular. The Generalized Singular Value
-Decomposition (GSVD) is used to address this issue. GSVD simultaneously
-diagonalizes both the within-class and between-class scatter matrices without
-the need to invert a singular matrix. This method is believed to be more
-accurate than PCA-LDA (as in \code{MASS::lda}) because it also considers the
-information in the between-class scatter matrix.
-}
-\examples{
-fit <- ldaGSVD(Species~., data = iris)
-# prediction
-predict(fit,iris)
-}
-\references{
-Ye, J., Janardan, R., Park, C. H., & Park, H. (2004). \emph{An
-optimization criterion for generalized discriminant analysis on
-undersampled problems}. IEEE Transactions on Pattern Analysis and Machine
-Intelligence
-
-Howland, P., Jeon, M., & Park, H. (2003). \emph{Structure preserving dimension
-reduction for clustered text data based on the generalized singular value
-decomposition}. SIAM Journal on Matrix Analysis and Applications
-}
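The `misClassCost` convention documented in `man/Treee.Rd` above (C[i, j] is the cost of classifying into class i when the truth is class j) is easiest to see with a small example. The sketch below is illustrative and not part of the diff; giving the matrix row/column names that match the response levels is an assumption.

```r
# Minimal sketch: supplying an asymmetric misclassification cost matrix,
# following the C[i, j] convention documented in man/Treee.Rd above.
# Dimnames matching levels(response) are an assumption, not from the diff.
library(LDATree)

response <- iris[, 5]
classes <- levels(response)

cost <- matrix(1, nrow = length(classes), ncol = length(classes),
               dimnames = list(classes, classes))
diag(cost) <- 0                        # correct classification costs nothing
cost["virginica", "versicolor"] <- 5   # penalize this particular mistake more

fit <- Treee(datX = iris[, -5], response = response,
             misClassCost = cost, verbose = FALSE)
```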
diff --git a/man/plot.Treee.Rd b/man/plot.Treee.Rd
index 1ce7b2f..f09d1ee 100644
--- a/man/plot.Treee.Rd
+++ b/man/plot.Treee.Rd
@@ -2,59 +2,58 @@
 % Please edit documentation in R/plot.R
 \name{plot.Treee}
 \alias{plot.Treee}
-\title{Plot a Treee object}
+\title{Plot a Decision Tree or Specific Node}
 \usage{
-\method{plot}{Treee}(tree, datX, response, node = -1, ...)
+\method{plot}{Treee}(x, datX, response, node = -1, ...)
 }
 \arguments{
-\item{node}{the node index that you are interested in. By default, it is set
-to \code{-1} and the overall tree structure is drawn.}
+\item{x}{A fitted model object of class \code{Treee}, typically the result of the
+\code{\link[=Treee]{Treee()}} function.}
 
-\item{...}{further arguments passed to or from other methods.}
+\item{datX}{A data frame of predictor variables. Required for plotting
+individual nodes.}
 
-\item{x}{a fitted model object of class \code{Treee}, which is assumed to be the
-result of the \code{\link[=Treee]{Treee()}} function.}
+\item{response}{A vector of response values. Required for plotting individual
+nodes.}
 
-\item{data}{the original data you used to fit the \code{Treee} object if you want
-the individual plot for each node. Otherwise, you can leave this parameter
-blank if you only need the overall tree structure diagram.}
+\item{node}{An integer specifying the node to plot. If \code{node = -1}, the
+entire tree is plotted. Default is \code{-1}.}
+
+\item{...}{Additional arguments passed to the plotting functions.}
 }
 \value{
-For overall tree structure (\code{node = -1}), A figure of class
-\code{visNetwork} is drawn. Otherwise, a figure of class \code{ggplot} is drawn.
+A \code{visNetwork} interactive plot of the decision tree if \code{node = -1},
+or a \code{ggplot2} object if a specific node is plotted.
 }
 \description{
-Provide a diagram of the whole tree structure or a scatter/density plot for a
-specific tree node.
+This function visualizes either the entire decision tree or a specific node
+within the tree. The tree is displayed as an interactive network of nodes
+and edges, while individual nodes are shown as scatter or density plots
+created with \code{ggplot2}.
 }
-\section{Overall tree structure}{
+\section{Overall Tree Structure}{
 
-A full tree diagram (via the R package \link{visNetwork}) is shown if \code{node} is
-not provided (default is \code{-1}). The color shows the most common (plurality)
-class inside each node. The size of each terminal node is based on its
-relative sample size. Under every node, you see the plurality class, the
-fraction of the correctly predicted training sample vs. the node's sample
-size, and the node index, respectively. When you click on the node, an
-information panel with more details will appear.
+A full tree diagram is displayed using \link{visNetwork} when \code{node} is not
+specified (the default is \code{-1}). The color represents the most common
+(plurality) class within each node, and the size of each terminal node
+reflects its relative sample size. Below each node, the fraction of
+correctly predicted training samples and the total sample size for that
+node are shown, along with the node index. Clicking on a node opens an
+information panel with additional details.
 }
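Because the overall plot is a `visNetwork` htmlwidget (per the `\value` section above), it should be possible to save the interactive diagram as a standalone HTML file. A hedged sketch, assuming the returned widget can be passed directly to `visNetwork::visSave()`:

```r
# Illustrative sketch: export the interactive tree diagram.
# Assumes plot(fit) returns a visNetwork widget, as documented above.
library(LDATree)

fit <- Treee(datX = iris[, -5], response = iris[, 5], verbose = FALSE)
p <- plot(fit)                              # interactive tree (visNetwork)
visNetwork::visSave(p, file = "tree.html")  # save as standalone HTML
```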
-\section{Individual plot for each node}{
+\section{Individual Node Plot}{
 
-The node index and the original training data are required to return a more
-detailed plot within a specific node. The density plot will be provided
-when only two levels are left for the response variable in a node (like in
-a binary classification problem). Samples are projected down to their first
-linear discriminant scores (LD1). A scatter plot will be provided if a node
-contains more than two classes. Samples are projected down to their first
-and second linear discriminant scores.
+To plot a specific node, you must provide the node index along with the
+original training predictors (\code{datX}) and responses (\code{response}). A scatter
+plot is generated if more than one discriminant score is available;
+otherwise, a density plot is created. Samples are projected onto their
+linear discriminant score(s).
 }
 \examples{
-fit <- Treee(Species~., data = iris)
-# plot the overall tree
-plot(fit)
-# plot a certain node
-plot(fit, iris, node = 1)
+fit <- Treee(datX = iris[, -5], response = iris[, 5], verbose = FALSE)
+plot(fit) # plot the overall tree
+plot(fit, datX = iris[, -5], response = iris[, 5], node = 1) # plot a specific node
 }
diff --git a/man/predict.Treee.Rd b/man/predict.Treee.Rd
index 3cdb836..bb796f7 100644
--- a/man/predict.Treee.Rd
+++ b/man/predict.Treee.Rd
@@ -4,47 +4,44 @@
 \alias{predict.Treee}
 \title{Predictions from a fitted Treee object}
 \usage{
-\method{predict}{Treee}(
-  object,
-  newdata,
-  type = c("response", "prob", "all"),
-  insideCV = FALSE,
-  newY = NULL,
-  ...
-)
+\method{predict}{Treee}(object, newdata, type = c("response", "prob", "all"), ...)
 }
 \arguments{
-\item{object}{a fitted model object of class \code{Treee}, which is assumed to be
-the result of the \code{\link[=Treee]{Treee()}} function.}
+\item{object}{A fitted model object of class \code{Treee}, typically the result of
+the \code{\link[=Treee]{Treee()}} function.}
 
-\item{newdata}{data frame containing the values at which predictions are
-required. Missing values are allowed.}
+\item{newdata}{A data frame containing the predictor variables. Missing
+values are allowed and will be handled according to the fitted tree's
+\code{missingMethod}.}
 
-\item{type}{character string denoting the type of predicted value returned.
-The default is to return the predicted class (\code{type = 'response'}). The
-predicted posterior probabilities for each class will be returned if \code{type = 'prob'}. \code{'all'} returns a data frame with predicted classes, posterior
-probabilities, and the predicted node indices.}
+\item{type}{A character string specifying the type of prediction to return.
+Options are:
+\itemize{
+\item \code{'response'}: returns the predicted class for each observation (default).
+\item \code{'prob'}: returns a data frame of posterior probabilities for each class.
+\item \code{'all'}: returns a data frame containing predicted classes, posterior probabilities, and the predicted node indices.
+}}
 
-\item{...}{further arguments passed to or from other methods.}
+\item{...}{Additional arguments passed to or from other methods.}
 }
 \value{
-The function returns different values based on the \code{type}, if
+Depending on the value of \code{type}, the function returns:
 \itemize{
-\item \code{type = 'response'}: vector of predicted responses.
-\item \code{type = 'prob'}: a data frame of the posterior probabilities. Each class takes a column.
-\item \code{type = 'all'}: a data frame contains the predicted responses, posterior probabilities, and the predicted node indices.
+\item If \code{type = 'response'}: A character vector of predicted class labels.
+\item If \code{type = 'prob'}: A data frame of posterior probabilities, where each class has its own column.
+\item If \code{type = 'all'}: A data frame containing predicted class labels, posterior probabilities, and the predicted node indices.
 }
 
-Note: for factor predictors, if it contains a level which is not used to
-grow the tree, it will be converted to missing and will be imputed according
-to the \code{missingMethod} in the fitted tree.
+Note: For factor predictors, if a level not present in the training data is
+found in \code{newdata}, it will be treated as missing and handled according to
+the \code{missingMethod} specified in the fitted tree.
 }
 \description{
-Prediction of test data using a fitted Treee object
+Generate predictions on new data using a fitted \code{Treee} model.
 }
 \examples{
-fit <- Treee(Species~., data = iris)
-predict(fit,iris)
-# output prosterior probabilities
-predict(fit,iris,type = "prob")
+fit <- Treee(datX = iris[, -5], response = iris[, 5], verbose = FALSE)
+head(predict(fit, iris)) # Predicted classes
+head(predict(fit, iris[, -5], type = "prob")) # Posterior probabilities
+head(predict(fit, iris[, -5], type = "all")) # Full details
 }
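The note above about unseen factor levels is worth a concrete illustration. A minimal sketch; the `group` column is synthetic, added here purely for demonstration and not part of the package's own examples:

```r
# Sketch of the unseen-factor-level rule documented above: a level absent
# from training is treated as missing and imputed via the tree's
# missingMethod. The 'group' predictor is synthetic, for illustration only.
library(LDATree)

datX <- iris[, -5]
datX$group <- factor(rep(c("a", "b"), length.out = nrow(datX)))
fit <- Treee(datX = datX, response = iris[, 5], verbose = FALSE)

newX <- datX[1:5, ]
newX$group <- factor(rep("c", 5))  # level "c" never appeared during training
predict(fit, newX)                 # predictions are still returned, no error
```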
diff --git a/man/predict.ldaGSVD.Rd b/man/predict.ldaGSVD.Rd
deleted file mode 100644
index b1836fb..0000000
--- a/man/predict.ldaGSVD.Rd
+++ /dev/null
@@ -1,60 +0,0 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/ldaGSVD.R
-\name{predict.ldaGSVD}
-\alias{predict.ldaGSVD}
-\title{Predictions from a fitted ldaGSVD object}
-\usage{
-\method{predict}{ldaGSVD}(object, newdata, type = c("response", "prob"), ...)
-}
-\arguments{
-\item{object}{a fitted model object of class \code{ldaGSVD}, which is assumed to
-be the result of the \code{\link[=ldaGSVD]{ldaGSVD()}} function.}
-
-\item{newdata}{data frame containing the values at which predictions are
-required. Missing values are NOT allowed.}
-
-\item{type}{character string denoting the type of predicted value returned.
-The default is to return the predicted class (\code{type} = 'response'). The
-predicted posterior probabilities for each class will be returned if \code{type}
-= 'prob'.}
-
-\item{...}{further arguments passed to or from other methods.}
-}
-\value{
-The function returns different values based on the \code{type}, if
-\itemize{
-\item \code{type = 'response'}: vector of predicted responses.
-\item \code{type = 'prob'}: a data frame of the posterior probabilities. Each class takes a column.
-}
-}
-\description{
-Prediction of test data using a fitted ldaGSVD object
-}
-\details{
-Unlike the original paper, which uses the k-nearest neighbor (k-NN) as the
-classifier, we use a faster and more straightforward likelihood-based method.
-One limitation of the traditional likelihood-based method for LDA is that it
-ceases to work when there are Linear Discriminant (LD) directions with zero
-variance in the within-class scatter matrix. However, when using LDA/GSVD,
-all chosen LD directions possess non-zero variance in the between-class
-scatter matrix. This implies that LD directions with zero variance in the
-within-class scatter matrix will yield the highest Fisher's ratio. Therefore,
-to get these directions higher weights, we manually adjust the zero variance
-to \code{1e-15} for computational reasons.
-}
-\examples{
-fit <- ldaGSVD(Species~., data = iris)
-predict(fit,iris)
-# output prosterior probabilities
-predict(fit,iris,type = "prob")
-}
-\references{
-Ye, J., Janardan, R., Park, C. H., & Park, H. (2004). \emph{An
-optimization criterion for generalized discriminant analysis on
-undersampled problems}. IEEE Transactions on Pattern Analysis and Machine
-Intelligence
-
-Howland, P., Jeon, M., & Park, H. (2003). \emph{Structure preserving dimension
-reduction for clustered text data based on the generalized singular value
-decomposition}. SIAM Journal on Matrix Analysis and Applications
-}
diff --git a/tests/testthat/test-Treee.R b/tests/testthat/test-Treee.R
index 7860cee..0bd99ab 100644
--- a/tests/testthat/test-Treee.R
+++ b/tests/testthat/test-Treee.R
@@ -1,3 +1,7 @@
-test_that("works on iris data", {
-  expect_equal(predict(Treee(Species~., data = iris),iris)[c(1,51,101)], c("setosa", "versicolor", "virginica"))
+test_that("folda: works on tibble", {
+  skip_on_cran()
+  dat <- ggplot2::diamonds[1:100,]
+  fit <- Treee(dat[, -2], response = dat[[2]], verbose = FALSE)
+  result <- predict(fit, dat)
+  expect_equal(result[1:4], c("Very Good", "Ideal", "Ideal", "Premium"))
 })
diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R
new file mode 100644
index 0000000..29571ac
--- /dev/null
+++ b/tests/testthat/test-utils.R
@@ -0,0 +1,32 @@
+test_that("updatePriorAndMisClassCost works", {
+  skip_on_cran()
+
+  # Simple case
+  test1 <- updatePriorAndMisClassCost(prior = NULL,
+                                      misClassCost = NULL,
+                                      response = factor(LETTERS[c(1,1,1,2,2,3)]),
+                                      insideNode = FALSE)
+  expect1 <- list(prior = c(A = 1, B = 1, C = 1), misClassCost = structure(
+    c(0, 1, 1, 1, 0, 1, 1, 1, 0), dim = c(3L, 3L), dimnames = list(c("A", "B", "C"), c("A", "B", "C"))))
+  expect_equal(test1, expect1)
+
+  # prior != obs
+  test2 <- updatePriorAndMisClassCost(prior = c(1,1,2),
+                                      misClassCost = NULL,
+                                      response = factor(LETTERS[c(1,1,1,2,2,3)]),
+                                      insideNode = FALSE)
+  expect_equal(length(unique(test2$prior / c(2,3,12))), 1)
+
+  # Subset the classes
+  priorAndMisClassCost <- list(prior = structure(c(A = 1, B = 2, C = 3), class = "table", dim = 3L, dimnames = list(
+    c("A", "B", "C"))), misClassCost = structure(c(0, 1, 2, 3, 0, 4, 5, 6, 0), dim = c(3L, 3L),
+    dimnames = list(c("A", "B", "C"), c("A", "B", "C"))))
+  test3 <- updatePriorAndMisClassCost(prior = priorAndMisClassCost$prior,
+                                      misClassCost = priorAndMisClassCost$misClassCost,
+                                      response = factor(LETTERS[c(1,1,1,3)]),
+                                      insideNode = TRUE)
+  expect3 <- list(prior = structure(c(A = 0.5, C = 0.5), class = "table", dim = 2L, dimnames = list(
+    c("A", "C"))), misClassCost = structure(c(0, 2, 5, 0), dim = c(2L,
+    2L), dimnames = list(c("A", "C"), c("A", "C"))))
+  expect_equal(test3, expect3)
+})
diff --git a/vignettes/LDATree.Rmd b/vignettes/LDATree.Rmd
index 9457a45..985ff58 100644
--- a/vignettes/LDATree.Rmd
+++ b/vignettes/LDATree.Rmd
@@ -36,16 +36,16 @@ Currently, `LDATree` offers two methods to construct a tree:
 
 1. The second approach involves pruning: it permits the building of a larger tree, which is then pruned using cross-validation.
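Since the vignette chunk below is commented out in this diff, here is a runnable counterpart on `iris` (instead of `mpg`, to keep it self-contained), a minimal sketch using the `datX`/`response` interface documented in `man/Treee.Rd`:

```r
# The two construction methods described above, sketched on iris.
library(LDATree)

datX <- iris[, -5]
response <- iris[, 5]

# 1. Direct stopping via p-value-based pre-pruning (the default)
fit <- Treee(datX = datX, response = response, pruneMethod = "pre", verbose = FALSE)

# 2. Grow a larger tree, then prune it back with cross-validation
set.seed(443)
fitCV <- Treee(datX = datX, response = response, pruneMethod = "post", verbose = FALSE)
```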
```{r,fig.asp=0.618,out.width = "100%",fig.align = "center", echo=TRUE} -mpg <- as.data.frame(ggplot2::mpg) -datX <- mpg[, -5] # All predictors without Y -response <- mpg[, 5] # we try to predict "cyl" (number of cylinders) - -# Build a tree using direct-stopping rule -fit <- Treee(datX = datX, response = response, pruneMethod = "pre", verbose = FALSE) - -# Build a tree using cross-validation -set.seed(443) -fitCV <- Treee(datX = datX, response = response, pruneMethod = "post", verbose = FALSE) +# mpg <- as.data.frame(ggplot2::mpg) +# datX <- mpg[, -5] # All predictors without Y +# response <- mpg[, 5] # we try to predict "cyl" (number of cylinders) +# +# # Build a tree using direct-stopping rule +# fit <- Treee(datX = datX, response = response, pruneMethod = "pre", verbose = FALSE) +# +# # Build a tree using cross-validation +# set.seed(443) +# fitCV <- Treee(datX = datX, response = response, pruneMethod = "post", verbose = FALSE) ``` # Plot the Tree @@ -67,28 +67,28 @@ fitCV <- Treee(datX = datX, response = response, pruneMethod = "post", verbose = ```{r,fig.asp=0.618,out.width = "100%",fig.align = "center", echo=TRUE} # Three types of individual plots -# 1. Scatter plot on first two LD scores -plot(fit, datX = datX, response = response, node = 1) - -# 2. Density plot on the first LD score -plot(fit, datX = datX, response = response, node = 3) - -# 3. A message -plot(fit, datX = datX, response = response, node = 2) +# # 1. Scatter plot on first two LD scores +# plot(fit, datX = datX, response = response, node = 1) +# +# # 2. Density plot on the first LD score +# plot(fit, datX = datX, response = response, node = 3) +# +# # 3. A message +# plot(fit, datX = datX, response = response, node = 2) ``` # Make Predictions ```{r,fig.asp=0.618,out.width = "100%",fig.align = "center", echo=TRUE} # Prediction only -predictions <- predict(fit, datX) -head(predictions) +# predictions <- predict(fit, datX) +# head(predictions) ``` ```{r,fig.asp=0.618,out.width = "100%",fig.align = "center", echo=TRUE} # A more informative prediction -predictions <- predict(fit, datX, type = "all") -head(predictions) +# predictions <- predict(fit, datX, type = "all") +# head(predictions) ``` # Missing Values @@ -97,10 +97,10 @@ For missing values, you do not need to specify anything (unless you want to); `L ```{r,fig.asp=0.618,out.width = "100%",fig.align = "center", echo=TRUE} # -datXmissing <- datX -for(i in 1:10) datXmissing[sample(234,20),i] <- NA -fitMissing <- Treee(datX = datXmissing, response = response, pruneMethod = "post", verbose = FALSE) -plot(fitMissing, datX = datXmissing, response = response, node = 1) +# datXmissing <- datX +# for(i in 1:10) datXmissing[sample(234,20),i] <- NA +# fitMissing <- Treee(datX = datXmissing, response = response, pruneMethod = "post", verbose = FALSE) +# plot(fitMissing, datX = datXmissing, response = response, node = 1) ``` # LDA/GSVD @@ -108,9 +108,9 @@ plot(fitMissing, datX = datXmissing, response = response, node = 1) As we re-implement the LDA/GSVD and apply it in the model fitting, a by-product is the `ldaGSVD` function. Feel free to play with it and see how it compares to `MASS::lda`. 
```{r,fig.asp=0.618,out.width = "100%",fig.align = "center", echo=TRUE}
-fitLDAgsvd <- ldaGSVD(datX = datX, response = response)
-predictionsLDAgsvd <- predict(fitLDAgsvd, newdata = datX)
-mean(predictionsLDAgsvd == response) # Training error
+# fitLDAgsvd <- ldaGSVD(datX = datX, response = response)
+# predictionsLDAgsvd <- predict(fitLDAgsvd, newdata = datX)
+# mean(predictionsLDAgsvd == response) # Training accuracy
```
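With `ldaGSVD()` removed from the exports (see the NAMESPACE hunk above), the chunk above no longer runs. For the comparison the vignette text suggests, a `MASS::lda` baseline still works; `iris` is substituted for `mpg` here because `lda(x, grouping)` expects numeric predictors while `mpg` mixes types:

```r
# Baseline comparison with MASS::lda, as the vignette text suggests.
# iris is used instead of mpg since lda(x, grouping) needs numeric x.
library(MASS)

fitLDA <- lda(iris[, -5], grouping = iris[, 5])
predLDA <- predict(fitLDA, iris[, -5])$class
mean(predLDA == iris[, 5])  # training accuracy
```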