diff --git a/R/plot_pred.R b/R/plot_pred.R index f1484a87..64e1695b 100644 --- a/R/plot_pred.R +++ b/R/plot_pred.R @@ -1,6 +1,6 @@ #' Plot the predictor matrix of an imputation model #' -#' @param data A predictor matrix for `mice`, typically generated with [mice::make.predictorMatrix] or [mice::quickpred]. +#' @param data A predictor matrix for `mice`, typically generated with [mice::make.predictorMatrix] or [mice::quickpred], or an object of class [`mice::mids`]. #' @param vrb String, vector, or unquoted expression with variable name(s), default is "all". #' @param method Character string or vector with imputation methods. #' @param label Logical indicating whether predictor matrix values should be displayed. @@ -20,7 +20,17 @@ plot_pred <- label = TRUE, square = TRUE, rotate = FALSE) { - verify_data(data, pred = TRUE) + verify_data(data, pred = TRUE, imp = TRUE) + if (mice::is.mids(data)) { + if (!is.null(method)) { + cli::cli_warn( + c("!" = "Input `method` is ignored when `data` is of class `mids`.", + "i" = "The `method` vector from the `mids` object will be used.") + ) + } + method <- data$method + data <- data$predictorMatrix + } p <- nrow(data) if (!is.null(method) && is.character(method)) { if (length(method) == 1) { @@ -35,7 +45,8 @@ plot_pred <- ylabel <- "" } if (!is.character(method) || length(method) != p) { - cli::cli_abort("Method should be NULL or a character string or vector (of length 1 or `ncol(data)`).") + cli::cli_abort("Method should be `NULL` or a character string or vector + (of length 1 or `ncol(data)`).") } vrb <- substitute(vrb) if (vrb[1] == "all") { diff --git a/R/utils.R b/R/utils.R index 2cb42595..1edbc6d1 100644 --- a/R/utils.R +++ b/R/utils.R @@ -57,7 +57,7 @@ verify_data <- function(data, ) } } - if (imp && !df) { + if (imp && !df && !pred) { if (!mice::is.mids(data)) { cli::cli_abort( c( @@ -68,7 +68,18 @@ verify_data <- function(data, ) } } - if (pred) { + if (imp && pred){ + if (!(is.matrix(data) || mice::is.mids(data))) { + cli::cli_abort( + c( + "The 'data' argument requires an object of class 'matrix', or 'mids'.", + "i" = "Input object is of class {class(data)}." + ), + call. = FALSE + ) + } + } + if (pred && !imp) { if (!is.matrix(data)) { cli::cli_abort( c( diff --git a/man/plot_pred.Rd b/man/plot_pred.Rd index c71970a6..2d5c98f9 100644 --- a/man/plot_pred.Rd +++ b/man/plot_pred.Rd @@ -14,7 +14,7 @@ plot_pred( ) } \arguments{ -\item{data}{A predictor matrix for \code{mice}, typically generated with \link[mice:make.predictorMatrix]{mice::make.predictorMatrix} or \link[mice:quickpred]{mice::quickpred}.} +\item{data}{A predictor matrix for \code{mice}, typically generated with \link[mice:make.predictorMatrix]{mice::make.predictorMatrix} or \link[mice:quickpred]{mice::quickpred}, or an object of class \code{\link[mice:mids-class]{mice::mids}}.} \item{vrb}{String, vector, or unquoted expression with variable name(s), default is "all".} diff --git a/tests/testthat/test-plot_pred.R b/tests/testthat/test-plot_pred.R index 57d72649..ff0c9aa1 100644 --- a/tests/testthat/test-plot_pred.R +++ b/tests/testthat/test-plot_pred.R @@ -1,6 +1,7 @@ # create test objects dat <- mice::nhanes pred <- mice::quickpred(dat) +imp <- mice::mice(dat, printFlag = FALSE) # tests test_that("plot_pred creates ggplot object", { @@ -18,6 +19,7 @@ test_that("plot_pred creates ggplot object", { expect_s3_class(plot_pred(rbind( cbind(pred, "with space" = 0), "with space" = 0 )), "ggplot") + expect_s3_class(plot_pred(imp, vrb = c("age", "bmi")), "ggplot") }) test_that("plot_pred with incorrect argument(s)", {