diff --git a/R-package/R/xgb.ggplot.R b/R-package/R/xgb.ggplot.R index 3e2e6e8e9603..f5a4d5509987 100644 --- a/R-package/R/xgb.ggplot.R +++ b/R-package/R/xgb.ggplot.R @@ -102,6 +102,27 @@ xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med #' @export xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL, trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) { + if (inherits(data, "xgb.DMatrix")) { + stop( + "'xgb.ggplot.shap.summary' is not compatible with 'xgb.DMatrix' objects. Try passing a matrix or data.frame." + ) + } + cols_categ <- NULL + if (!is.null(model)) { + ftypes <- getinfo(model, "feature_type") + if (NROW(ftypes)) { + if (length(ftypes) != ncol(data)) { + stop(sprintf("'data' has incorrect number of columns (expected: %d, got: %d).", length(ftypes), ncol(data))) + } + cols_categ <- colnames(data)[ftypes == "c"] + } + } else if (inherits(data, "data.frame")) { + cols_categ <- names(data)[sapply(data, function(x) is.factor(x) || is.character(x))] + } + if (NROW(cols_categ)) { + warning("Categorical features are ignored in 'xgb.ggplot.shap.summary'.") + } + data_list <- xgb.shap.data( data = data, shap_contrib = shap_contrib, @@ -114,6 +135,10 @@ xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, subsample = subsample, max_observations = 10000 # 10,000 samples per feature. ) + if (NROW(cols_categ)) { + data_list <- lapply(data_list, function(x) x[, !(colnames(x) %in% cols_categ), drop = FALSE]) + } + p_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE) # Reverse factor levels so that the first level is at the top of the plot p_data[, "feature" := factor(feature, rev(levels(feature)))] @@ -134,7 +159,8 @@ xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, #' @param data_list The result of `xgb.shap.data()`. #' @param normalize Whether to standardize feature values to mean 0 and #' standard deviation 1. This is useful for comparing multiple features on the same -#' plot. Default is `FALSE`. +#' plot. Default is `FALSE`. Note that it cannot be used when the data contains +#' categorical features. #' @return A `data.table` containing the observation ID, the feature name, the #' feature value (normalized if specified), and the SHAP contribution value. #' @noRd diff --git a/R-package/R/xgb.plot.shap.R b/R-package/R/xgb.plot.shap.R index 79c2ed328a7a..443020e1ac7e 100644 --- a/R-package/R/xgb.plot.shap.R +++ b/R-package/R/xgb.plot.shap.R @@ -2,7 +2,7 @@ #' #' Visualizes SHAP values against feature values to gain an impression of feature effects. #' -#' @param data The data to explain as a `matrix` or `dgCMatrix`. +#' @param data The data to explain as a `matrix`, `dgCMatrix`, or `data.frame`. #' @param shap_contrib Matrix of SHAP contributions of `data`. #' The default (`NULL`) computes it from `model` and `data`. #' @param features Vector of column indices or feature names to plot. When `NULL` @@ -285,8 +285,11 @@ xgb.plot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, to xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, model = NULL, trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL, max_observations = 100000) { - if (!is.matrix(data) && !inherits(data, "dgCMatrix")) - stop("data: must be either matrix or dgCMatrix") + if (!inherits(data, c("matrix", "dsparseMatrix", "data.frame"))) + stop("data: must be matrix, sparse matrix, or data.frame.") + if (inherits(data, "data.frame") && length(class(data)) > 1L) { + data <- as.data.frame(data) + } if (is.null(shap_contrib) && (is.null(model) || !inherits(model, "xgb.Booster"))) stop("when shap_contrib is not provided, one must provide an xgb.Booster model") @@ -311,7 +314,14 @@ xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, stop("if model has no feature_names, columns in `data` must match features in model") if (!is.null(subsample)) { - idx <- sample(x = seq_len(nrow(data)), size = as.integer(subsample * nrow(data)), replace = FALSE) + if (subsample <= 0 || subsample >= 1) { + stop("'subsample' must be a number between zero and one (non-inclusive).") + } + sample_size <- as.integer(subsample * nrow(data)) + if (sample_size < 2) { + stop("Sampling fraction involves less than 2 rows.") + } + idx <- sample(x = seq_len(nrow(data)), size = sample_size, replace = FALSE) } else { idx <- seq_len(min(nrow(data), max_observations)) } diff --git a/R-package/man/xgb.plot.shap.Rd b/R-package/man/xgb.plot.shap.Rd index c94fb2bb34c4..f4f51059d653 100644 --- a/R-package/man/xgb.plot.shap.Rd +++ b/R-package/man/xgb.plot.shap.Rd @@ -33,7 +33,7 @@ xgb.plot.shap( ) } \arguments{ -\item{data}{The data to explain as a \code{matrix} or \code{dgCMatrix}.} +\item{data}{The data to explain as a \code{matrix}, \code{dgCMatrix}, or \code{data.frame}.} \item{shap_contrib}{Matrix of SHAP contributions of \code{data}. The default (\code{NULL}) computes it from \code{model} and \code{data}.} diff --git a/R-package/man/xgb.plot.shap.summary.Rd b/R-package/man/xgb.plot.shap.summary.Rd index 7fbca6fd9c10..f6df2daca758 100644 --- a/R-package/man/xgb.plot.shap.summary.Rd +++ b/R-package/man/xgb.plot.shap.summary.Rd @@ -30,7 +30,7 @@ xgb.plot.shap.summary( ) } \arguments{ -\item{data}{The data to explain as a \code{matrix} or \code{dgCMatrix}.} +\item{data}{The data to explain as a \code{matrix}, \code{dgCMatrix}, or \code{data.frame}.} \item{shap_contrib}{Matrix of SHAP contributions of \code{data}. The default (\code{NULL}) computes it from \code{model} and \code{data}.} diff --git a/R-package/tests/testthat/test_helpers.R b/R-package/tests/testthat/test_helpers.R index 7724d6bc5da6..dcaf4f2fd4c4 100644 --- a/R-package/tests/testthat/test_helpers.R +++ b/R-package/tests/testthat/test_helpers.R @@ -449,6 +449,26 @@ test_that("xgb.shap.data works with subsampling", { expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib)) }) +test_that("xgb.shap.data works with data frames", { + data(mtcars) + df <- mtcars + df$cyl <- factor(df$cyl) + x <- df[, -1] + y <- df$mpg + dm <- xgb.DMatrix(x, label = y, nthread = 1L) + model <- xgb.train( + data = dm, + params = list( + max_depth = 2, + nthread = 1 + ), + nrounds = 2 + ) + data_list <- xgb.shap.data(data = df[, -1], model = model, top_n = 2, subsample = 0.8) + expect_equal(NROW(data_list$data), as.integer(0.8 * nrow(df))) + expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib)) +}) + test_that("prepare.ggplot.shap.data works", { .skip_if_vcd_not_available() data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2) @@ -472,6 +492,44 @@ test_that("xgb.plot.shap.summary works", { expect_silent(xgb.ggplot.shap.summary(data = sparse_matrix, model = bst.Tree, top_n = 2)) }) +test_that("xgb.plot.shap.summary ignores categorical features", { + .skip_if_vcd_not_available() + data(mtcars) + df <- mtcars + df$cyl <- factor(df$cyl) + levels(df$cyl) <- c("a", "b", "c") + x <- df[, -1] + y <- df$mpg + dm <- xgb.DMatrix(x, label = y, nthread = 1L) + model <- xgb.train( + data = dm, + params = list( + max_depth = 2, + nthread = 1 + ), + nrounds = 2 + ) + expect_warning({ + xgb.ggplot.shap.summary(data = x, model = model, top_n = 2) + }) + + x_num <- mtcars[, -1] + x_num$gear <- as.numeric(x_num$gear) - 1 + x_num <- as.matrix(x_num) + dm <- xgb.DMatrix(x_num, label = y, feature_types = c(rep("q", 8), "c", "q"), nthread = 1L) + model <- xgb.train( + data = dm, + params = list( + max_depth = 2, + nthread = 1 + ), + nrounds = 2 + ) + expect_warning({ + xgb.ggplot.shap.summary(data = x_num, model = model, top_n = 2) + }) +}) + test_that("check.deprecation works", { ttt <- function(a = NNULL, DUMMY = NULL, ...) { check.deprecation(...)