Skip to content

Commit

Permalink
Quantile predictions output constructor (#1191)
Browse files Browse the repository at this point in the history
* small change to predict checks

* add vctrs for quantiles and test, refactor *_rq_preds

* revise tests

* Apply some of the suggestions from code review

Co-authored-by: Simon P. Couch <[email protected]>

* rename tests on suggestion from code review

* export missing funs from vctrs for formatting

* convert errors to snapshot tests

* pass call through input check

* update snapshots for caller_env

* rename to parsnip_quantiles, add format snapshot tests

* Apply suggestions from @topepo

Co-authored-by: Max Kuhn <[email protected]>

* rename parsnip_quantiles to quantile_pred

* rename parsnip_quantiles to quantile_pred and add vector probability check

* fix: two bugs introduced earlier

* add formatting tests for single quantile

* replace walk with a loop to avoid "Error in map()"

* remove row/col names

* adjust quantile_pred format

* as_tibble method

* updated NEWS file

* add PR number

* small new update

* helper methods

* update docs

* re-enable quantiles prediction for #1203

* update some tests

* no longer needed

* use tibble::new_tibble

* braces

* test as_tibble

* remove print methods

---------

Co-authored-by: Simon P. Couch <[email protected]>
Co-authored-by: Max Kuhn <[email protected]>
Co-authored-by: ‘topepo’ <‘[email protected]’>
  • Loading branch information
4 people authored Sep 13, 2024
1 parent 6168556 commit 3bdb471
Show file tree
Hide file tree
Showing 13 changed files with 551 additions and 70 deletions.
16 changes: 16 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

S3method(.censoring_weights_graf,default)
S3method(.censoring_weights_graf,model_fit)
S3method(as.matrix,quantile_pred)
S3method(as_tibble,quantile_pred)
S3method(augment,model_fit)
S3method(autoplot,glmnet)
S3method(autoplot,model_fit)
Expand Down Expand Up @@ -36,10 +38,12 @@ S3method(extract_spec_parsnip,model_fit)
S3method(fit,model_spec)
S3method(fit_xy,gen_additive_mod)
S3method(fit_xy,model_spec)
S3method(format,quantile_pred)
S3method(glance,model_fit)
S3method(has_multi_predict,default)
S3method(has_multi_predict,model_fit)
S3method(has_multi_predict,workflow)
S3method(median,quantile_pred)
S3method(multi_predict,"_C5.0")
S3method(multi_predict,"_earth")
S3method(multi_predict,"_elnet")
Expand All @@ -54,6 +58,7 @@ S3method(multi_predict_args,default)
S3method(multi_predict_args,model_fit)
S3method(multi_predict_args,workflow)
S3method(nullmodel,default)
S3method(obj_print_footer,quantile_pred)
S3method(predict,"_elnet")
S3method(predict,"_glmnetfit")
S3method(predict,"_lognet")
Expand Down Expand Up @@ -172,6 +177,8 @@ S3method(update,svm_rbf)
S3method(varying_args,model_spec)
S3method(varying_args,recipe)
S3method(varying_args,step)
S3method(vec_ptype_abbr,quantile_pred)
S3method(vec_ptype_full,quantile_pred)
export("%>%")
export(.censoring_weights_graf)
export(.check_glmnet_penalty_fit)
Expand Down Expand Up @@ -226,6 +233,7 @@ export(extract_fit_engine)
export(extract_fit_time)
export(extract_parameter_dials)
export(extract_parameter_set_dials)
export(extract_quantile_levels)
export(extract_spec_parsnip)
export(find_engine_files)
export(fit)
Expand Down Expand Up @@ -280,6 +288,7 @@ export(new_model_spec)
export(null_model)
export(null_value)
export(nullmodel)
export(obj_print_footer)
export(parsnip_addin)
export(pls)
export(poisson_reg)
Expand Down Expand Up @@ -307,6 +316,7 @@ export(prepare_data)
export(print_model_spec)
export(prompt_missing_implementation)
export(proportional_hazards)
export(quantile_pred)
export(rand_forest)
export(repair_call)
export(req_pkgs)
Expand Down Expand Up @@ -350,6 +360,8 @@ export(update_model_info_file)
export(update_spec)
export(varying)
export(varying_args)
export(vec_ptype_abbr)
export(vec_ptype_full)
export(xgb_predict)
export(xgb_train)
import(rlang)
Expand Down Expand Up @@ -402,6 +414,7 @@ importFrom(stats,as.formula)
importFrom(stats,binomial)
importFrom(stats,coef)
importFrom(stats,delete.response)
importFrom(stats,median)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,model.offset)
Expand All @@ -426,5 +439,8 @@ importFrom(utils,globalVariables)
importFrom(utils,head)
importFrom(utils,methods)
importFrom(utils,stack)
importFrom(vctrs,obj_print_footer)
importFrom(vctrs,vec_ptype_abbr)
importFrom(vctrs,vec_ptype_full)
importFrom(vctrs,vec_size)
importFrom(vctrs,vec_unique)
5 changes: 4 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# parsnip (development version)


* A new model mode (`"quantile regression"`) was added. Including:
* A function to create a new vector class called `quantile_pred()` was added (#1191).
* A `linear_reg()` engine for `"quantreg"`.

* `fit_xy()` currently raises an error for `gen_additive_mod()` model specifications as the default engine (`"mgcv"`) specifies smoothing terms in model formulas. However, some engines specify smooths via additional arguments, in which case the restriction on `fit_xy()` is excessive. parsnip will now only raise an error when fitting a `gen_additive_mod()` with `fit_xy()` when using the `"mgcv"` engine (#775).

* Aligned `null_model()` with other model types; the model type now has an engine argument that defaults to `"parsnip"` and is checked with the same machinery that checks other model types in the package (#1083).
Expand Down
229 changes: 204 additions & 25 deletions R/aaa_quantiles.R
Original file line number Diff line number Diff line change
@@ -1,43 +1,222 @@
# Helpers for quantile regression models

check_quantile_level <- function(x, object, call) {
if ( object$mode != "quantile regression" ) {
if (object$mode != "quantile regression") {
return(invisible(TRUE))
} else {
if ( is.null(x) ) {
if (is.null(x)) {
cli::cli_abort("In {.fn check_mode}, at least one value of
{.arg quantile_level} must be specified for quantile regression models.")
}
}
if (any(is.na(x))) {
cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.",
call = call)
}
x <- sort(unique(x))
# TODO we need better vectorization here, otherwise we get things like:
# "Error during wrapup: i In index: 2." in the traceback.
res <-
purrr::map(x,
~ check_number_decimal(.x, min = 0, max = 1,
arg = "quantile_level", call = call,
allow_infinite = FALSE)
)
check_vector_probability(x, arg = "quantile_level", call = call)
x
}

# Assumes the columns have the same order as quantile_level
restructure_rq_pred <- function(x, object) {
num_quantiles <- NCOL(x)
if ( num_quantiles == 1L ){
x <- matrix(x, ncol = 1)

# -------------------------------------------------------------------------
# A column vector of quantiles with an attribute

#' @importFrom vctrs vec_ptype_abbr
#' @export
vctrs::vec_ptype_abbr

#' @importFrom vctrs vec_ptype_full
#' @export
vctrs::vec_ptype_full


#' @export
vec_ptype_abbr.quantile_pred <- function(x, ...) {
n_lvls <- length(attr(x, "quantile_levels"))
cli::format_inline("qtl{?s}({n_lvls})")
}

#' @export
vec_ptype_full.quantile_pred <- function(x, ...) "quantiles"

new_quantile_pred <- function(values = list(), quantile_levels = double()) {
quantile_levels <- vctrs::vec_cast(quantile_levels, double())
vctrs::new_vctr(
values, quantile_levels = quantile_levels, class = "quantile_pred"
)
}

#' Create a vector containing sets of quantiles
#'
#' [quantile_pred()] is a special vector class used to efficiently store
#' predictions from a quantile regression model. It requires the same quantile
#' levels for each row being predicted.
#'
#' @param values A matrix of values. Each column should correspond to one of
#' the quantile levels.
#' @param quantile_levels A vector of probabilities corresponding to `values`.
#' @param x An object produced by [quantile_pred()].
#' @param .rows,.name_repair,rownames Arguments not used but required by the
#' original S3 method.
#' @param ... Not currently used.
#'
#' @export
#' @return
#' * [quantile_pred()] returns a vector of values associated with the
#' quantile levels.
#' * [extract_quantile_levels()] returns a numeric vector of levels.
#' * [as_tibble()] returns a tibble with rows `".pred_quantile"`,
#' `".quantile_levels"`, and `".row"`.
#' * [as.matrix()] returns an unnamed matrix with rows as sames, columns as
#' quantile levels, and entries are predictions.
#' @examples
#' .pred_quantile <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8))
#'
#' unclass(.pred_quantile)
#'
#' # Access the underlying information
#' extract_quantile_levels(.pred_quantile)
#'
#' # Matrix format
#' as.matrix(.pred_quantile)
#'
#' # Tidy format
#' tibble::as_tibble(.pred_quantile)
quantile_pred <- function(values, quantile_levels = double()) {
check_quantile_pred_inputs(values, quantile_levels)

quantile_levels <- vctrs::vec_cast(quantile_levels, double())
num_lvls <- length(quantile_levels)

if (ncol(values) != num_lvls) {
cli::cli_abort(
"The number of columns in {.arg values} must be equal to the length of
{.arg quantile_levels}."
)
}
rownames(values) <- NULL
colnames(values) <- NULL
values <- lapply(vctrs::vec_chop(values), drop)
new_quantile_pred(values, quantile_levels)
}

check_quantile_pred_inputs <- function(values, levels, call = caller_env()) {
if (any(is.na(levels))) {
cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.",
call = call)
}
n <- nrow(x)

if (!is.matrix(values)) {
cli::cli_abort(
"{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.",
call = call
)
}
check_vector_probability(levels, arg = "quantile_levels", call = call)

if (is.unsorted(levels)) {
cli::cli_abort(
"{.arg quantile_levels} must be sorted in increasing order.",
call = call
)
}
invisible(NULL)
}

#' @export
format.quantile_pred <- function(x, ...) {
quantile_levels <- attr(x, "quantile_levels")
if (length(quantile_levels) == 1L) {
x <- unlist(x)
out <- round(x, 3L)
out[is.na(x)] <- NA_real_
} else {
rng <- sapply(x, range, na.rm = TRUE)
out <- paste0("[", round(rng[1, ], 3L), ", ", round(rng[2, ], 3L), "]")
out[is.na(rng[1, ]) & is.na(rng[2, ])] <- NA_character_
m <- median(x)
out <- paste0("[", round(m, 3L), "]")
}
out
}

#' @importFrom vctrs obj_print_footer
#' @export
vctrs::obj_print_footer

#' @export
obj_print_footer.quantile_pred <- function(x, digits = 3, ...) {
lvls <- attr(x, "quantile_levels")
cat("# Quantile levels: ", format(lvls, digits = digits), "\n", sep = " ")
}

check_vector_probability <- function(x, ...,
allow_na = FALSE,
allow_null = FALSE,
arg = caller_arg(x),
call = caller_env()) {
for (d in x) {
check_number_decimal(
d, min = 0, max = 1,
arg = arg, call = call,
allow_na = allow_na,
allow_null = allow_null,
allow_infinite = FALSE
)
}
}

#' @export
median.quantile_pred <- function(x, ...) {
lvls <- attr(x, "quantile_levels")
loc_median <- (abs(lvls - 0.5) < sqrt(.Machine$double.eps))
if (any(loc_median)) {
return(map_dbl(x, ~ .x[min(which(loc_median))]))
}
if (length(lvls) < 2 || min(lvls) > 0.5 || max(lvls) < 0.5) {
return(rep(NA, vctrs::vec_size(x)))
}
map_dbl(x, ~ stats::approx(lvls, .x, xout = 0.5)$y)
}

restructure_rq_pred <- function(x, object) {
if (!is.matrix(x)) {
x <- as.matrix(x)
}
rownames(x) <- NULL
n_pred_quantiles <- ncol(x)
quantile_level <- object$spec$quantile_level
res <-
tibble::tibble(
.pred_quantile = as.vector(x),
.quantile_level = rep(quantile_level, each = n),
.row = rep(1:n, num_quantiles))
res <- vctrs::vec_split(x = res[,1:2], by = res[, ".row"])
res <- vctrs::vec_cbind(res$key, tibble::new_tibble(list(.pred_quantile = res$val)))
res$.row <- NULL
res

tibble::new_tibble(x = list(.pred_quantile = quantile_pred(x, quantile_level)))
}

#' @export
#' @rdname quantile_pred
extract_quantile_levels <- function(x) {
if (!inherits(x, "quantile_pred")) {
cli::cli_abort("{.arg x} should have class {.val quantile_pred}.")
}
attr(x, "quantile_levels")
}

#' @export
#' @rdname quantile_pred
as_tibble.quantile_pred <-
function (x, ..., .rows = NULL, .name_repair = "minimal", rownames = NULL) {
lvls <- attr(x, "quantile_levels")
n_samp <- length(x)
n_quant <- length(lvls)
tibble::tibble(
.pred_quantile = unlist(x),
.quantile_levels = rep(lvls, n_samp),
.row = rep(1:n_samp, each = n_quant)
)
}

#' @export
#' @rdname quantile_pred
as.matrix.quantile_pred <- function(x, ...) {
num_samp <- length(x)
matrix(unlist(x), nrow = num_samp)
}
2 changes: 1 addition & 1 deletion R/parsnip-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#' @importFrom stats .checkMFClasses .getXlevels as.formula binomial coef
#' @importFrom stats delete.response model.frame model.matrix model.offset
#' @importFrom stats model.response model.weights na.omit na.pass predict qnorm
#' @importFrom stats qt quantile setNames terms update
#' @importFrom stats qt quantile setNames terms update median
#' @importFrom tibble as_tibble is_tibble tibble
#' @importFrom tidyr gather
#' @importFrom utils capture.output getFromNamespace globalVariables head
Expand Down
12 changes: 9 additions & 3 deletions R/predict_quantile.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
#' @method predict_quantile model_fit
#' @export predict_quantile.model_fit
#' @export
predict_quantile.model_fit <- function(object, new_data, ...) {
predict_quantile.model_fit <- function(object,
new_data,
quantile = (1:9)/10,
interval = "none",
level = 0.95,
...) {

check_spec_pred_type(object, "quantile")

Expand All @@ -23,7 +28,7 @@ predict_quantile.model_fit <- function(object, new_data, ...) {
}

# Pass some extra arguments to be used in post-processor
object$spec$method$pred$quantile$args$quantile_level <- object$quantile_level
object$spec$method$pred$quantile$args$p <- quantile
pred_call <- make_pred_call(object$spec$method$pred$quantile)

res <- eval_tidy(pred_call)
Expand All @@ -40,5 +45,6 @@ predict_quantile.model_fit <- function(object, new_data, ...) {
# @keywords internal
# @rdname other_predict
# @inheritParams predict.model_fit
predict_quantile <- function (object, ...)
predict_quantile <- function (object, ...) {
UseMethod("predict_quantile")
}
Loading

0 comments on commit 3bdb471

Please sign in to comment.