-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Quantile predictions output constructor (#1191)
* small change to predict checks * add vctrs for quantiles and test, refactor *_rq_preds * revise tests * Apply some of the suggestions from code review Co-authored-by: Simon P. Couch <[email protected]> * rename tests on suggestion from code review * export missing funs from vctrs for formatting * convert errors to snapshot tests * pass call through input check * update snapshots for caller_env * rename to parsnip_quantiles, add format snapshot tests * Apply suggestions from @topepo Co-authored-by: Max Kuhn <[email protected]> * rename parsnip_quantiles to quantile_pred * rename parsnip_quantiles to quantile_pred and add vector probability check * fix: two bugs introduced earlier * add formatting tests for single quantile * replace walk with a loop to avoid "Error in map()" * remove row/col names * adjust quantile_pred format * as_tibble method * updated NEWS file * add PR number * small new update * helper methods * update docs * re-enable quantiles prediction for #1203 * update some tests * no longer needed * use tibble::new_tibble * braces * test as_tibble * remove print methods --------- Co-authored-by: Simon P. Couch <[email protected]> Co-authored-by: Max Kuhn <[email protected]> Co-authored-by: ‘topepo’ <‘[email protected]’>
- Loading branch information
1 parent
6168556
commit 3bdb471
Showing
13 changed files
with
551 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,222 @@ | ||
# Helpers for quantile regression models | ||
|
||
check_quantile_level <- function(x, object, call) { | ||
if ( object$mode != "quantile regression" ) { | ||
if (object$mode != "quantile regression") { | ||
return(invisible(TRUE)) | ||
} else { | ||
if ( is.null(x) ) { | ||
if (is.null(x)) { | ||
cli::cli_abort("In {.fn check_mode}, at least one value of | ||
{.arg quantile_level} must be specified for quantile regression models.") | ||
} | ||
} | ||
if (any(is.na(x))) { | ||
cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.", | ||
call = call) | ||
} | ||
x <- sort(unique(x)) | ||
# TODO we need better vectorization here, otherwise we get things like: | ||
# "Error during wrapup: i In index: 2." in the traceback. | ||
res <- | ||
purrr::map(x, | ||
~ check_number_decimal(.x, min = 0, max = 1, | ||
arg = "quantile_level", call = call, | ||
allow_infinite = FALSE) | ||
) | ||
check_vector_probability(x, arg = "quantile_level", call = call) | ||
x | ||
} | ||
|
||
# Assumes the columns have the same order as quantile_level | ||
restructure_rq_pred <- function(x, object) { | ||
num_quantiles <- NCOL(x) | ||
if ( num_quantiles == 1L ){ | ||
x <- matrix(x, ncol = 1) | ||
|
||
# ------------------------------------------------------------------------- | ||
# A column vector of quantiles with an attribute | ||
|
||
#' @importFrom vctrs vec_ptype_abbr | ||
#' @export | ||
vctrs::vec_ptype_abbr | ||
|
||
#' @importFrom vctrs vec_ptype_full | ||
#' @export | ||
vctrs::vec_ptype_full | ||
|
||
|
||
#' @export | ||
vec_ptype_abbr.quantile_pred <- function(x, ...) { | ||
n_lvls <- length(attr(x, "quantile_levels")) | ||
cli::format_inline("qtl{?s}({n_lvls})") | ||
} | ||
|
||
#' @export | ||
vec_ptype_full.quantile_pred <- function(x, ...) "quantiles" | ||
|
||
new_quantile_pred <- function(values = list(), quantile_levels = double()) { | ||
quantile_levels <- vctrs::vec_cast(quantile_levels, double()) | ||
vctrs::new_vctr( | ||
values, quantile_levels = quantile_levels, class = "quantile_pred" | ||
) | ||
} | ||
|
||
#' Create a vector containing sets of quantiles | ||
#' | ||
#' [quantile_pred()] is a special vector class used to efficiently store | ||
#' predictions from a quantile regression model. It requires the same quantile | ||
#' levels for each row being predicted. | ||
#' | ||
#' @param values A matrix of values. Each column should correspond to one of | ||
#' the quantile levels. | ||
#' @param quantile_levels A vector of probabilities corresponding to `values`. | ||
#' @param x An object produced by [quantile_pred()]. | ||
#' @param .rows,.name_repair,rownames Arguments not used but required by the | ||
#' original S3 method. | ||
#' @param ... Not currently used. | ||
#' | ||
#' @export | ||
#' @return | ||
#' * [quantile_pred()] returns a vector of values associated with the | ||
#' quantile levels. | ||
#' * [extract_quantile_levels()] returns a numeric vector of levels. | ||
#' * [as_tibble()] returns a tibble with rows `".pred_quantile"`, | ||
#' `".quantile_levels"`, and `".row"`. | ||
#' * [as.matrix()] returns an unnamed matrix with rows as sames, columns as | ||
#' quantile levels, and entries are predictions. | ||
#' @examples | ||
#' .pred_quantile <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) | ||
#' | ||
#' unclass(.pred_quantile) | ||
#' | ||
#' # Access the underlying information | ||
#' extract_quantile_levels(.pred_quantile) | ||
#' | ||
#' # Matrix format | ||
#' as.matrix(.pred_quantile) | ||
#' | ||
#' # Tidy format | ||
#' tibble::as_tibble(.pred_quantile) | ||
quantile_pred <- function(values, quantile_levels = double()) { | ||
check_quantile_pred_inputs(values, quantile_levels) | ||
|
||
quantile_levels <- vctrs::vec_cast(quantile_levels, double()) | ||
num_lvls <- length(quantile_levels) | ||
|
||
if (ncol(values) != num_lvls) { | ||
cli::cli_abort( | ||
"The number of columns in {.arg values} must be equal to the length of | ||
{.arg quantile_levels}." | ||
) | ||
} | ||
rownames(values) <- NULL | ||
colnames(values) <- NULL | ||
values <- lapply(vctrs::vec_chop(values), drop) | ||
new_quantile_pred(values, quantile_levels) | ||
} | ||
|
||
check_quantile_pred_inputs <- function(values, levels, call = caller_env()) { | ||
if (any(is.na(levels))) { | ||
cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.", | ||
call = call) | ||
} | ||
n <- nrow(x) | ||
|
||
if (!is.matrix(values)) { | ||
cli::cli_abort( | ||
"{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.", | ||
call = call | ||
) | ||
} | ||
check_vector_probability(levels, arg = "quantile_levels", call = call) | ||
|
||
if (is.unsorted(levels)) { | ||
cli::cli_abort( | ||
"{.arg quantile_levels} must be sorted in increasing order.", | ||
call = call | ||
) | ||
} | ||
invisible(NULL) | ||
} | ||
|
||
#' @export | ||
format.quantile_pred <- function(x, ...) { | ||
quantile_levels <- attr(x, "quantile_levels") | ||
if (length(quantile_levels) == 1L) { | ||
x <- unlist(x) | ||
out <- round(x, 3L) | ||
out[is.na(x)] <- NA_real_ | ||
} else { | ||
rng <- sapply(x, range, na.rm = TRUE) | ||
out <- paste0("[", round(rng[1, ], 3L), ", ", round(rng[2, ], 3L), "]") | ||
out[is.na(rng[1, ]) & is.na(rng[2, ])] <- NA_character_ | ||
m <- median(x) | ||
out <- paste0("[", round(m, 3L), "]") | ||
} | ||
out | ||
} | ||
|
||
#' @importFrom vctrs obj_print_footer | ||
#' @export | ||
vctrs::obj_print_footer | ||
|
||
#' @export | ||
obj_print_footer.quantile_pred <- function(x, digits = 3, ...) { | ||
lvls <- attr(x, "quantile_levels") | ||
cat("# Quantile levels: ", format(lvls, digits = digits), "\n", sep = " ") | ||
} | ||
|
||
check_vector_probability <- function(x, ..., | ||
allow_na = FALSE, | ||
allow_null = FALSE, | ||
arg = caller_arg(x), | ||
call = caller_env()) { | ||
for (d in x) { | ||
check_number_decimal( | ||
d, min = 0, max = 1, | ||
arg = arg, call = call, | ||
allow_na = allow_na, | ||
allow_null = allow_null, | ||
allow_infinite = FALSE | ||
) | ||
} | ||
} | ||
|
||
#' @export | ||
median.quantile_pred <- function(x, ...) { | ||
lvls <- attr(x, "quantile_levels") | ||
loc_median <- (abs(lvls - 0.5) < sqrt(.Machine$double.eps)) | ||
if (any(loc_median)) { | ||
return(map_dbl(x, ~ .x[min(which(loc_median))])) | ||
} | ||
if (length(lvls) < 2 || min(lvls) > 0.5 || max(lvls) < 0.5) { | ||
return(rep(NA, vctrs::vec_size(x))) | ||
} | ||
map_dbl(x, ~ stats::approx(lvls, .x, xout = 0.5)$y) | ||
} | ||
|
||
restructure_rq_pred <- function(x, object) { | ||
if (!is.matrix(x)) { | ||
x <- as.matrix(x) | ||
} | ||
rownames(x) <- NULL | ||
n_pred_quantiles <- ncol(x) | ||
quantile_level <- object$spec$quantile_level | ||
res <- | ||
tibble::tibble( | ||
.pred_quantile = as.vector(x), | ||
.quantile_level = rep(quantile_level, each = n), | ||
.row = rep(1:n, num_quantiles)) | ||
res <- vctrs::vec_split(x = res[,1:2], by = res[, ".row"]) | ||
res <- vctrs::vec_cbind(res$key, tibble::new_tibble(list(.pred_quantile = res$val))) | ||
res$.row <- NULL | ||
res | ||
|
||
tibble::new_tibble(x = list(.pred_quantile = quantile_pred(x, quantile_level))) | ||
} | ||
|
||
#' @export | ||
#' @rdname quantile_pred | ||
extract_quantile_levels <- function(x) { | ||
if (!inherits(x, "quantile_pred")) { | ||
cli::cli_abort("{.arg x} should have class {.val quantile_pred}.") | ||
} | ||
attr(x, "quantile_levels") | ||
} | ||
|
||
#' @export | ||
#' @rdname quantile_pred | ||
as_tibble.quantile_pred <- | ||
function (x, ..., .rows = NULL, .name_repair = "minimal", rownames = NULL) { | ||
lvls <- attr(x, "quantile_levels") | ||
n_samp <- length(x) | ||
n_quant <- length(lvls) | ||
tibble::tibble( | ||
.pred_quantile = unlist(x), | ||
.quantile_levels = rep(lvls, n_samp), | ||
.row = rep(1:n_samp, each = n_quant) | ||
) | ||
} | ||
|
||
#' @export | ||
#' @rdname quantile_pred | ||
as.matrix.quantile_pred <- function(x, ...) { | ||
num_samp <- length(x) | ||
matrix(unlist(x), nrow = num_samp) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.