diff --git a/DESCRIPTION b/DESCRIPTION index bbc4856..747fa76 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -6,18 +6,28 @@ Authors@R: c( person("Hannah", "Frick", , "hannah@posit.co", role = "aut"), person("Emil", "HvitFeldt", , "emil.hvitfeldt@posit.co", role = "aut"), person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), - person(given = "Posit Software, PBC", role = c("cph", "fnd")) + person("Posit Software, PBC", role = c("cph", "fnd")) ) Description: Sandbox for a postprocessor object. License: MIT + file LICENSE +URL: https://github.com/tidymodels/container +BugReports: https://github.com/tidymodels/container/issues +Imports: + cli, + dplyr, + generics, + hardhat, + probably (>= 1.0.3.9000), + purrr, + rlang (>= 1.1.0), + tibble, + tidyselect Suggests: + modeldata, testthat (>= 3.0.0) +Remotes: + tidymodels/probably Config/testthat/edition: 3 Encoding: UTF-8 Roxygen: list(markdown = TRUE) RoxygenNote: 7.3.1 -URL: https://github.com/tidymodels/container -BugReports: https://github.com/tidymodels/container/issues -Imports: - cli, - rlang (>= 1.1.0) diff --git a/NAMESPACE b/NAMESPACE index 20159f1..80050c2 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,63 @@ # Generated by roxygen2: do not edit by hand +S3method(fit,container) +S3method(fit,equivocal_zone) +S3method(fit,numeric_calibration) +S3method(fit,numeric_range) +S3method(fit,predictions_custom) +S3method(fit,probability_calibration) +S3method(fit,probability_threshold) +S3method(predict,container) +S3method(predict,equivocal_zone) +S3method(predict,numeric_calibration) +S3method(predict,numeric_range) +S3method(predict,predictions_custom) +S3method(predict,probability_calibration) +S3method(predict,probability_threshold) +S3method(print,container) +S3method(print,equivocal_zone) +S3method(print,numeric_calibration) +S3method(print,numeric_range) +S3method(print,predictions_custom) +S3method(print,probability_calibration) +S3method(print,probability_threshold) +S3method(required_pkgs,equivocal_zone) +S3method(required_pkgs,numeric_calibration) +S3method(required_pkgs,numeric_range) +S3method(required_pkgs,predictions_custom) +S3method(required_pkgs,probability_calibration) +S3method(required_pkgs,probability_threshold) +S3method(tunable,equivocal_zone) +S3method(tunable,numeric_calibration) +S3method(tunable,numeric_range) +S3method(tunable,predictions_custom) +S3method(tunable,probability_calibration) +S3method(tunable,probability_threshold) +export("%>%") +export(adjust_equivocal_zone) +export(adjust_numeric_calibration) +export(adjust_numeric_range) +export(adjust_predictions_custom) +export(adjust_probability_calibration) +export(adjust_probability_threshold) +export(container) +export(extract_parameter_dials) +export(extract_parameter_set_dials) +export(fit) +export(required_pkgs) +export(tidy) +export(tunable) +export(tune_args) import(rlang) importFrom(cli,cli_abort) importFrom(cli,cli_inform) importFrom(cli,cli_warn) +importFrom(dplyr,"%>%") +importFrom(generics,fit) +importFrom(generics,required_pkgs) +importFrom(generics,tidy) +importFrom(generics,tunable) +importFrom(generics,tune_args) +importFrom(hardhat,extract_parameter_dials) +importFrom(hardhat,extract_parameter_set_dials) +importFrom(stats,predict) diff --git a/R/adjust-equivocal-zone.R b/R/adjust-equivocal-zone.R new file mode 100644 index 0000000..6df78b8 --- /dev/null +++ b/R/adjust-equivocal-zone.R @@ -0,0 +1,118 @@ +#' Apply an equivocal zone to a binary classification model. +#' +#' @param x A [container()]. 
+#' @param value A numeric value (between zero and 1/2) or [hardhat::tune()]. The +#' value is the size of the buffer around the threshold. +#' @param threshold A numeric value (between zero and one) or [hardhat::tune()]. +#' @examples +#' library(dplyr) +#' library(modeldata) +#' +#' post_obj <- +#' container(mode = "classification") %>% +#' adjust_equivocal_zone(value = 1 / 4) +#' +#' +#' post_res <- fit( +#' post_obj, +#' two_class_example, +#' outcome = c(truth), +#' estimate = c(predicted), +#' probabilities = c(Class1, Class2) +#' ) +#' +#' predict(post_res, two_class_example) +#' @export +adjust_equivocal_zone <- function(x, value = 0.1, threshold = 1 / 2) { + check_container(x) + if (!is_tune(value)) { + check_number_decimal(value, min = 0, max = 1 / 2) + } + if (!is_tune(threshold)) { + check_number_decimal(threshold, min = 10^-10, max = 1 - 10^-10) + } + + op <- + new_operation( + "equivocal_zone", + inputs = "probability", + outputs = "class", + arguments = list(value = value, threshold = threshold), + results = list(), + trained = FALSE + ) + + new_container( + mode = x$mode, + type = x$type, + operations = c(x$operations, list(op)), + columns = x$dat, + ptype = x$ptype, + call = current_env() + ) +} + +#' @export +print.equivocal_zone <- function(x, ...) { + # check for tune() first + + if (is_tune(x$arguments$value)) { + cli::cli_bullets(c("*" = "Add equivocal zone of optimized size.")) + } else { + trn <- ifelse(x$trained, " [trained]", "") + cli::cli_bullets(c( + "*" = "Add equivocal zone of size + {signif(x$arguments$value, digits = 3)}.{trn}" + )) + } + invisible(x) +} + +#' @export +fit.equivocal_zone <- function(object, data, container = NULL, ...) { + new_operation( + class(object), + inputs = object$inputs, + outputs = object$outputs, + arguments = object$arguments, + results = list(), + trained = TRUE + ) +} + +#' @export +predict.equivocal_zone <- function(object, new_data, container, ...) { + est_nm <- container$columns$estimate + prob_nm <- container$columns$probabilities[1] + lvls <- levels(new_data[[est_nm]]) + col_syms <- syms(prob_nm[1]) + cls_pred <- probably::make_two_class_pred( + new_data[[prob_nm]], + levels = lvls, + buffer = object$arguments$value, + threshold = object$arguments$threshold + ) + new_data[[est_nm]] <- cls_pred # todo convert to factor? + new_data +} + +#' @export +required_pkgs.equivocal_zone <- function(x, ...) { + c("container", "probably") +} + +#' @export +tunable.equivocal_zone <- function(x, ...) { + tibble::new_tibble(list( + name = "buffer", + call_info = list(list(pkg = "dials", fun = "buffer")), + source = "container", + component = "equivocal_zone", + component_id = "equivocal_zone" + )) +} + +# todo missing methods: +# todo tune_args +# todo tidy +# todo extract_parameter_set_dials diff --git a/R/adjust-numeric-calibration.R b/R/adjust-numeric-calibration.R new file mode 100644 index 0000000..10a9ec5 --- /dev/null +++ b/R/adjust-numeric-calibration.R @@ -0,0 +1,99 @@ +#' Re-calibrate numeric predictions +#' +#' @param x A [container()]. +#' @param calibrator A pre-trained calibration method from the \pkg{probably} +#' package, such as [probably::cal_estimate_linear()]. 
+#' @examples
+#' library(modeldata)
+#' library(probably)
+#' library(tibble)
+#'
+#' # create example data
+#' set.seed(1)
+#' dat <- tibble(y = rnorm(100), y_pred = y/2 + rnorm(100))
+#'
+#' dat
+#'
+#' # calibrate numeric predictions
+#' reg_cal <- cal_estimate_linear(dat, truth = y, estimate = y_pred)
+#'
+#' # specify calibration
+#' reg_ctr <-
+#'   container(mode = "regression") %>%
+#'   adjust_numeric_calibration(reg_cal)
+#'
+#' # "train" container
+#' reg_ctr_trained <- fit(reg_ctr, dat, outcome = y, estimate = y_pred)
+#'
+#' predict(reg_ctr_trained, dat)
+#' @export
+adjust_numeric_calibration <- function(x, calibrator) {
+  check_container(x)
+  check_required(calibrator)
+  if (!inherits(calibrator, "cal_regression")) {
+    cli_abort(
+      "{.arg calibrator} should be a \\
+       {.help [{.cls cal_regression} object](probably::cal_estimate_linear)}, \\
+       not {.obj_type_friendly {calibrator}}."
+    )
+  }
+
+  op <-
+    new_operation(
+      "numeric_calibration",
+      inputs = "numeric",
+      outputs = "numeric",
+      arguments = list(calibrator = calibrator),
+      results = list(),
+      trained = FALSE
+    )
+
+  new_container(
+    mode = x$mode,
+    type = x$type,
+    operations = c(x$operations, list(op)),
+    columns = x$dat,
+    ptype = x$ptype,
+    call = current_env()
+  )
+}
+
+#' @export
+print.numeric_calibration <- function(x, ...) {
+  trn <- ifelse(x$trained, " [trained]", "")
+  cli::cli_bullets(c("*" = "Re-calibrate numeric predictions.{trn}"))
+  invisible(x)
+}
+
+#' @export
+fit.numeric_calibration <- function(object, data, container = NULL, ...) {
+  new_operation(
+    class(object),
+    inputs = object$inputs,
+    outputs = object$outputs,
+    arguments = object$arguments,
+    results = list(),
+    trained = TRUE
+  )
+}
+
+#' @export
+predict.numeric_calibration <- function(object, new_data, container, ...) {
+  probably::cal_apply(new_data, object$arguments$calibrator)
+}
+
+# todo probably needs required_pkgs methods for cal objects
+#' @export
+required_pkgs.numeric_calibration <- function(x, ...) {
+  c("container", "probably")
+}
+
+#' @export
+tunable.numeric_calibration <- function(x, ...) {
+  no_param
+}
+
+# todo missing methods:
+# todo tune_args
+# todo tidy
+# todo extract_parameter_set_dials
diff --git a/R/adjust-numeric-range.R b/R/adjust-numeric-range.R
new file mode 100644
index 0000000..3a5e21c
--- /dev/null
+++ b/R/adjust-numeric-range.R
@@ -0,0 +1,106 @@
+#' Truncate the range of numeric predictions
+#'
+#' @param x A [container()].
+#' @param upper_limit,lower_limit A numeric value, NA (for no truncation) or
+#'   [hardhat::tune()].
+#' @export
+adjust_numeric_range <- function(x, lower_limit = -Inf, upper_limit = Inf) {
+  # remaining input checks are done via probably::bound_prediction
+  check_container(x)
+
+  op <-
+    new_operation(
+      "numeric_range",
+      inputs = "numeric",
+      outputs = "numeric",
+      arguments = list(lower_limit = lower_limit, upper_limit = upper_limit),
+      results = list(),
+      trained = FALSE
+    )
+
+  new_container(
+    mode = x$mode,
+    type = x$type,
+    operations = c(x$operations, list(op)),
+    columns = x$dat,
+    ptype = x$ptype,
+    call = current_env()
+  )
+}
+
+#' @export
+print.numeric_range <- function(x, ...)
{ + trn <- ifelse(x$trained, " [trained]", "") + # todo could be na + if (!is_tune(x$arguments$lower_limit)) { + if (!is_tune(x$arguments$upper_limit)) { + rng_txt <- + paste0( + "between [", + signif(x$arguments$lower_limit, 3), + ", ", + signif(x$arguments$upper_limit, 3), + "]" + ) + } else { + rng_txt <- paste0("between [", signif(x$arguments$lower_limit, 3), ", ?]") + } + } else { + if (!is_tune(x$arguments$upper_limit)) { + rng_txt <- paste0("between [?, ", signif(x$arguments$upper_limit, 3), "]") + } else { + rng_txt <- "between [?, ?]" + } + } + + cli::cli_bullets(c("*" = "Constrain numeric predictions to be {rng_txt}{trn}.")) + invisible(x) +} + +#' @export +fit.numeric_range <- function(object, data, container = NULL, ...) { + new_operation( + class(object), + inputs = object$inputs, + outputs = object$outputs, + arguments = object$arguments, + results = list(), + trained = TRUE + ) +} + +#' @export +predict.numeric_range <- function(object, new_data, container, ...) { + est_nm <- container$columns$estimate + lo <- object$arguments$lower_limit + hi <- object$arguments$upper_limit + + # todo depends on tm predict col names + new_data[[est_nm]] <- + probably::bound_prediction(new_data, lower_limit = lo, upper_limit = hi)[[est_nm]] + new_data +} + +#' @export +required_pkgs.numeric_range <- function(x, ...) { + c("container", "probably") +} + +#' @export +tunable.numeric_range <- function(x, ...) { + tibble::new_tibble(list( + name = c("lower_limit", "upper_limit"), + call_info = list( + list(pkg = "dials", fun = "lower_limit"), # todo make these dials functions + list(pkg = "dials", fun = "upper_limit") + ), + source = "container", + component = "numeric_range", + component_id = "numeric_range" + )) +} + +# todo missing methods: +# todo tune_args +# todo tidy +# todo extract_parameter_set_dials diff --git a/R/adjust-predictions-custom.R b/R/adjust-predictions-custom.R new file mode 100644 index 0000000..8b61ab9 --- /dev/null +++ b/R/adjust-predictions-custom.R @@ -0,0 +1,88 @@ +#' Change or add variables +#' +#' @param x A [container()]. +#' @param .pkgs A character string of extra packages that are needed to execute +#' the commands. +#' @param ... Name-value pairs of expressions. See [dplyr::mutate()]. +#' @examples +#' library(dplyr) +#' library(modeldata) +#' +#' post_obj <- +#' container(mode = "classification") %>% +#' adjust_equivocal_zone() %>% +#' adjust_predictions_custom(linear_predictor = binomial()$linkfun(Class2)) +#' +#' +#' post_res <- fit( +#' post_obj, +#' two_class_example, +#' outcome = c(truth), +#' estimate = c(predicted), +#' probabilities = c(Class1, Class2) +#' ) +#' +#' predict(post_res, two_class_example) +#' @export +adjust_predictions_custom <- function(x, ..., .pkgs = character(0)) { + check_container(x) + cmds <- enquos(...) + + op <- + new_operation( + "predictions_custom", + inputs = "everything", + outputs = "everything", + arguments = list(commands = cmds, pkgs = .pkgs), + results = list(), + trained = FALSE + ) + + new_container( + mode = x$mode, + type = x$type, + operations = c(x$operations, list(op)), + columns = x$dat, + ptype = x$ptype, + call = current_env() + ) +} + +#' @export +print.predictions_custom <- function(x, ...) { + trn <- ifelse(x$trained, " [trained]", "") + cli::cli_bullets(c("*" = "Adjust predictions using custom code.{trn}")) + invisible(x) +} + +#' @export +fit.predictions_custom <- function(object, data, container = NULL, ...) 
{
+  new_operation(
+    class(object),
+    inputs = object$inputs,
+    outputs = object$outputs,
+    arguments = object$arguments,
+    results = list(),
+    trained = TRUE
+  )
+}
+
+#' @export
+predict.predictions_custom <- function(object, new_data, container, ...) {
+  dplyr::mutate(new_data, !!!object$arguments$commands)
+}
+
+#' @export
+required_pkgs.predictions_custom <- function(x, ...) {
+  unique(c("container", x$arguments$pkgs))
+}
+
+#' @export
+tunable.predictions_custom <- function(x, ...) {
+  no_param
+}
+
+# todo missing methods:
+# todo tune_args
+# todo tidy
+# todo extract_parameter_set_dials
diff --git a/R/adjust-probability-calibration.R b/R/adjust-probability-calibration.R
new file mode 100644
index 0000000..9808285
--- /dev/null
+++ b/R/adjust-probability-calibration.R
@@ -0,0 +1,77 @@
+#' Re-calibrate classification probability predictions
+#'
+#' @param x A [container()].
+#' @param calibrator A pre-trained calibration method from the \pkg{probably}
+#'   package, such as [probably::cal_estimate_logistic()].
+#' @export
+adjust_probability_calibration <- function(x, calibrator) {
+  check_container(x)
+  cls <- c("cal_binary", "cal_multinomial")
+  check_required(calibrator)
+  if (!inherits_any(calibrator, cls)) {
+    cli_abort(
+      "{.arg calibrator} should be a \\
+       {.help [{.cls cal_binary} or {.cls cal_multinomial} object](probably::cal_estimate_logistic)}, \\
+       not {.obj_type_friendly {calibrator}}."
+    )
+  }
+
+  op <-
+    new_operation(
+      "probability_calibration",
+      inputs = "probability",
+      outputs = "probability_class",
+      arguments = list(calibrator = calibrator),
+      results = list(),
+      trained = FALSE
+    )
+
+  new_container(
+    mode = x$mode,
+    type = x$type,
+    operations = c(x$operations, list(op)),
+    columns = x$dat,
+    ptype = x$ptype,
+    call = current_env()
+  )
+}
+
+#' @export
+print.probability_calibration <- function(x, ...) {
+  trn <- ifelse(x$trained, " [trained]", "")
+  cli::cli_bullets(c("*" = "Re-calibrate classification probabilities.{trn}"))
+  invisible(x)
+}
+
+#' @export
+fit.probability_calibration <- function(object, data, container = NULL, ...) {
+  new_operation(
+    class(object),
+    inputs = object$inputs,
+    outputs = object$outputs,
+    arguments = object$arguments,
+    results = list(),
+    trained = TRUE
+  )
+}
+
+#' @export
+predict.probability_calibration <- function(object, new_data, container, ...) {
+  probably::cal_apply(new_data, object$arguments$calibrator)
+}
+
+# todo probably needs required_pkgs methods for cal objects
+#' @export
+required_pkgs.probability_calibration <- function(x, ...) {
+  c("container", "probably")
+}
+
+#' @export
+tunable.probability_calibration <- function(x, ...) {
+  no_param
+}
+
+# todo missing methods:
+# todo tune_args
+# todo tidy
+# todo extract_parameter_set_dials
diff --git a/R/adjust-probability-threshold.R b/R/adjust-probability-threshold.R
new file mode 100644
index 0000000..fcf13a3
--- /dev/null
+++ b/R/adjust-probability-threshold.R
@@ -0,0 +1,110 @@
+#' Change the event threshold
+#'
+#' @param x A [container()].
+#' @param threshold A numeric value (between zero and one) or [hardhat::tune()].
+#' @examples +#' library(dplyr) +#' library(modeldata) +#' +#' post_obj <- +#' container(mode = "classification") %>% +#' adjust_probability_threshold(threshold = .1) +#' +#' two_class_example %>% count(predicted) +#' +#' post_res <- fit( +#' post_obj, +#' two_class_example, +#' outcome = c(truth), +#' estimate = c(predicted), +#' probabilities = c(Class1, Class2) +#' ) +#' +#' predict(post_res, two_class_example) %>% count(predicted) +#' @export +adjust_probability_threshold <- function(x, threshold = 0.5) { + check_container(x) + if (!is_tune(threshold)) { + check_number_decimal(threshold, min = 10^-10, max = 1 - 10^-10) + } + + op <- + new_operation( + "probability_threshold", + inputs = "probability", + outputs = "class", + arguments = list(threshold = threshold), + results = list(), + trained = FALSE + ) + + new_container( + mode = x$mode, + type = x$type, + operations = c(x$operations, list(op)), + columns = x$dat, + ptype = x$ptype, + call = current_env() + ) +} + +#' @export +print.probability_threshold <- function(x, ...) { + # check for tune() first + + if (is_tune(x$arguments$threshold)) { + cli::cli_bullets(c("*" = "Adjust probability threshold to optimized value.")) + } else { + trn <- ifelse(x$trained, " [trained]", "") + cli::cli_bullets(c( + "*" = "Adjust probability threshold to \\ + {signif(x$arguments$threshold, digits = 3)}.{trn}" + )) + } + invisible(x) +} + +#' @export +fit.probability_threshold <- function(object, data, container = NULL, ...) { + new_operation( + class(object), + inputs = object$inputs, + outputs = object$outputs, + arguments = object$arguments, + results = list(), + trained = TRUE + ) +} + +#' @export +predict.probability_threshold <- function(object, new_data, container, ...) { + est_nm <- container$columns$estimate + prob_nm <- container$columns$probabilities[1] + lvls <- levels(new_data[[est_nm]]) + + new_data[[est_nm]] <- + ifelse(new_data[[prob_nm]] >= object$arguments$threshold, lvls[1], lvls[2]) + new_data[[est_nm]] <- factor(new_data[[est_nm]], levels = lvls) + new_data +} + +#' @export +required_pkgs.probability_threshold <- function(x, ...) { + c("container") +} + +#' @export +tunable.probability_threshold <- function(x, ...) { + tibble::new_tibble(list( + name = "threshold", + call_info = list(list(pkg = "dials", fun = "threshold")), + source = "container", + component = "probability_threshold", + component_id = "probability_threshold" + )) +} + +# todo missing methods: +# todo tune_args +# todo tidy +# todo extract_parameter_set_dials diff --git a/R/container-package.R b/R/container-package.R index 9e15a93..361e307 100644 --- a/R/container-package.R +++ b/R/container-package.R @@ -1,8 +1,10 @@ #' @import rlang #' @importFrom cli cli_abort cli_warn cli_inform +#' @importFrom stats predict #' @keywords internal "_PACKAGE" ## usethis namespace: start +utils::globalVariables("data") ## usethis namespace: end NULL diff --git a/R/container.R b/R/container.R new file mode 100644 index 0000000..e199a7d --- /dev/null +++ b/R/container.R @@ -0,0 +1,178 @@ +#' Declare post-processing for model predictions +#' +#' @param mode The model's mode, one of `"classification"`, or `"regression"`. +#' Modes of `"censored regression"` are not currently supported. +#' @param type The model sub-type. Possible values are `"unknown"`, `"regression"`, +#' `"binary"`, or `"multiclass"`. +#' @param outcome The name of the outcome variable. +#' @param estimate The name of the point estimate (e.g. predicted class). 
In
+#'   tidymodels, this corresponds to column names `.pred`, `.pred_class`, or
+#'   `.pred_time`.
+#' @param probabilities The names of class probability estimates (if any). For
+#'   classification, these should be given in the order of the factor levels of
+#'   the `estimate`.
+#' @param time The name of the predicted event time. (not yet supported)
+#' @examples
+#'
+#' container(mode = "regression")
+#' @export
+container <- function(mode, type = "unknown", outcome = NULL, estimate = NULL,
+                      probabilities = NULL, time = NULL) {
+  columns <-
+    list(
+      outcome = outcome,
+      type = type,
+      estimate = estimate,
+      probabilities = probabilities,
+      time = time
+    )
+
+  new_container(
+    mode,
+    type,
+    operations = list(),
+    columns = columns,
+    ptype = tibble::new_tibble(list()),
+    call = current_env()
+  )
+}
+
+new_container <- function(mode, type, operations, columns, ptype, call) {
+  mode <- arg_match0(mode, c("regression", "classification"))
+
+  if (mode == "regression") {
+    type <- "regression"
+  }
+
+  type <- arg_match0(type, c("unknown", "regression", "binary", "multiclass"))
+
+  if (!is.list(operations)) {
+    cli_abort("The {.arg operations} argument should be a list.", call = call)
+  }
+
+  is_oper <- purrr::map_lgl(operations, ~ inherits(.x, "operation"))
+  if (length(is_oper) > 0 && !all(is_oper)) {
+    bad_oper <- names(is_oper)[!is_oper]
+    cli_abort("The following {.arg operations} do not have the class \\
+              {.val operation}: {bad_oper}.", call = call)
+  }
+
+  # validate operation order and check duplicates
+  validate_order(operations, mode, call)
+
+  # check columns
+  res <- list(
+    mode = mode, type = type, operations = operations,
+    columns = columns, ptype = ptype
+  )
+  class(res) <- "container"
+  res
+}
+
+#' @export
+print.container <- function(x, ...) {
+  cli::cli_h1("Container")
+
+  num_op <- length(x$operations)
+  cli::cli_text(
+    "A {ifelse(x$type == 'unknown', '', x$type)} postprocessor \\
+    with {num_op} operation{?s}{cli::qty(num_op+1)}{?./:}"
+  )
+
+  if (num_op > 0) {
+    cli::cli_text("\n")
+    res <- purrr::map(x$operations, print)
+  }
+
+  invisible(x)
+}
+
+#' @export
+fit.container <- function(object, .data, outcome, estimate, probabilities = c(),
+                          time = c(), ...) {
+  # ------------------------------------------------------------------------------
+  # set columns via tidyselect
+
+  columns <- list()
+  columns$outcome <- names(tidyselect::eval_select(enquo(outcome), .data))
+  columns$estimate <- names(tidyselect::eval_select(enquo(estimate), .data))
+
+  probabilities <- tidyselect::eval_select(enquo(probabilities), .data)
+  if (length(probabilities) > 0) {
+    columns$probabilities <- names(probabilities)
+  } else {
+    columns$probabilities <- character(0)
+  }
+
+  time <- tidyselect::eval_select(enquo(time), .data)
+  if (length(time) > 0) {
+    columns$time <- names(time)
+  } else {
+    columns$time <- character(0)
+  }
+
+  .data <- .data[, names(.data) %in% unlist(columns)]
+  if (!tibble::is_tibble(.data)) {
+    .data <- tibble::as_tibble(.data)
+  }
+  ptype <- .data[0, ]
+
+  object <- set_container_type(object, .data[[columns$outcome]])
+
+  object <- new_container(
+    object$mode,
+    object$type,
+    operations = object$operations,
+    columns = columns,
+    ptype = ptype,
+    call = current_env()
+  )
+
+  num_oper <- length(object$operations)
+  for (op in seq_len(num_oper)) {
+    object$operations[[op]] <- fit(object$operations[[op]], .data, object)
+    .data <- predict(object$operations[[op]], .data, object)
+  }
+
+  # todo Add a fitted container class?
+ object +} + +#' @export +predict.container <- function(object, new_data, ...) { + # validate levels/classes + num_oper <- length(object$operations) + for (op in seq_len(num_oper)) { + new_data <- predict(object$operations[[op]], new_data, object) + } + if (!tibble::is_tibble(new_data)) { + new_data <- tibble::as_tibble(new_data) + } + new_data +} + +set_container_type <- function(object, y) { + if (object$type != "unknown") { + return(object) + } + if (is.factor(y)) { + lvls <- levels(y) + if (length(lvls) == 2) { + object$type <- "binary" + } else { + object$type <- "multiclass" + } + } else if (is.numeric(y)) { + object$type <- "regression" + } else { + cli_abort("Only factor and numeric outcomes are currently supported.") + } + object +} + +# todo: where to validate #levels? +# todo setup eval_time +# todo missing methods: +# todo tune_args +# todo tidy +# todo extract_parameter_set_dials diff --git a/R/reexport.R b/R/reexport.R new file mode 100644 index 0000000..c0802e3 --- /dev/null +++ b/R/reexport.R @@ -0,0 +1,31 @@ +#' @importFrom generics fit +#' @export +generics::fit + +#' @importFrom generics tidy +#' @export +generics::tidy + +#' @importFrom generics required_pkgs +#' @export +generics::required_pkgs + +#' @importFrom generics tunable +#' @export +generics::tunable + +#' @importFrom generics tune_args +#' @export +generics::tune_args + +#' @importFrom hardhat extract_parameter_set_dials +#' @export +hardhat::extract_parameter_set_dials + +#' @importFrom hardhat extract_parameter_dials +#' @export +hardhat::extract_parameter_dials + +#' @importFrom dplyr %>% +#' @export +dplyr::`%>%` diff --git a/R/utils.R b/R/utils.R new file mode 100644 index 0000000..e90ed46 --- /dev/null +++ b/R/utils.R @@ -0,0 +1,62 @@ +is_tune <- function(x) { + if (!is.call(x)) { + return(FALSE) + } + isTRUE(identical(quote(tune), x[[1]])) +} + +# for operations with no tunable parameters + +no_param <- + tibble::tibble( + name = character(0), + call_info = list(), + source = character(0), + component = character(0), + component_id = character(0) + ) + +# These values are used to specify "what will we need for the adjustment?" and +# "what will we change?". For the outputs, we cannot change the probabilities +# without changing the classes. This is important because we are going to have +# to define constrains on the order of adjustments. + +input_vals <- c("numeric", "probability", "class", "everything") +output_vals <- c("numeric", "probability_class", "class", "everything") + +new_operation <- function(cls, inputs, outputs, arguments, results = list(), + trained, ...) 
{ + inputs <- arg_match0(inputs, input_vals) + outputs <- arg_match0(outputs, output_vals) + + check_logical(trained) + + res <- + list( + inputs = inputs, + outputs = outputs, + arguments = arguments, + results = results, + trained = trained + ) + class(res) <- c(cls, "operation") + res +} + +# predicates ------------------------------------------------------------------- +is_container <- function(x) { + inherits(x, "container") +} + +# ad-hoc checking -------------------------------------------------------------- +check_container <- function(x, call = caller_env(), arg = caller_arg(x)) { + if (!is_container(x)) { + cli::cli_abort( + "{.arg {arg}} should be a {.help [{.cls container}](container::container)}, \\ + not {.obj_type_friendly {x}}.", + call = call + ) + } + + invisible() +} diff --git a/R/validation-rules.R b/R/validation-rules.R new file mode 100644 index 0000000..a6a74a7 --- /dev/null +++ b/R/validation-rules.R @@ -0,0 +1,85 @@ +validate_order <- function(ops, mode, call) { + orderings <- + tibble::new_tibble(list( + name = purrr::map_chr(ops, ~ class(.x)[1]), + input = purrr::map_chr(ops, ~ .x$inputs), + output_numeric = purrr::map_lgl(ops, ~ grepl("numeric", .x$outputs)), + output_prob = purrr::map_lgl(ops, ~ grepl("probability", .x$outputs)), + output_class = purrr::map_lgl(ops, ~ grepl("class", .x$outputs)), + output_all = purrr::map_lgl(ops, ~ grepl("everything", .x$outputs)) + )) + + if (length(ops) < 2) { + return(invisible(orderings)) + } + + if (mode == "classification") { + check_classification_order(orderings, call) + } else { + check_regression_order(orderings, call) + } + + invisible(orderings) +} + +check_classification_order <- function(x, call) { + cal_ind <- which(grepl("calibration$", x$name)) + eq_ind <- which(grepl("equivocal", x$name)) + prob_ind <- which(x$output_prob) + class_ind <- which(x$output_class) + + # does probability steps come after steps that change the hard classes? + if (length(prob_ind) > 0) { + if (any(class_ind < prob_ind)) { + cli_abort("Operations that change the hard class predictions \\ + must come after operations that update the class \\ + probability estimates.", call = call) + } + } + + # todo ? calibration should _probably_ come before anything that is not a mutate + + # do any steps come before Eq zones + if (length(eq_ind) > 0) { + if (any(eq_ind < class_ind) | any(eq_ind < prob_ind)) { + cli_abort("Equivocal zone addition should come after operations \\ + that update the class probability estimates or hard \\ + class predictions.", call = call) + } + } + + # besides mutates, are there duplicate steps? + check_duplicates(x, call) + + invisible(x) +} + +check_regression_order <- function(x, call) { + cal_ind <- which(grepl("calibration$", x$name)) + num_ind <- which(x$output_numeric) + + # does calibration come after other steps? + # currently excluding mutates form this check + if (length(cal_ind) > 0) { + if (any(num_ind < cal_ind)) { + cli_abort( + "Calibration should come before other operations.", + call = call + ) + } + } + + # besides mutates, are there duplicate steps? 
+ check_duplicates(x, call) + + invisible(x) +} + +check_duplicates <- function(x, call) { + non_mutates <- table(x$name[x$name != "predictions_custom"]) + if (any(non_mutates > 1)) { + bad_oper <- names(non_mutates[non_mutates > 1]) + cli_abort("Operations cannot be duplicated: {.val {bad_oper}}", call = call) + } + invisible(x) +} diff --git a/README.Rmd b/README.Rmd index 6f2aeab..f1206b1 100644 --- a/README.Rmd +++ b/README.Rmd @@ -23,6 +23,8 @@ knitr::opts_chunk$set( The goal of container is to provide a container for postprocessing. +This is going to undergo massive changes (especially the name), so please treat it as experimental and don't depend on the syntax staying the same. + ## Installation You can install the development version of container like so: diff --git a/README.md b/README.md index 18d74cf..0356729 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,10 @@ status](https://www.r-pkg.org/badges/version/container)](https://CRAN.R-project. The goal of container is to provide a container for postprocessing. +This is going to undergo massive changes (especially the name), so +please treat it as experimental and don’t depend on the syntax staying +the same. + ## Installation You can install the development version of container like so: diff --git a/inst/examples/container_regression_example.pdf b/inst/examples/container_regression_example.pdf new file mode 100644 index 0000000..7fab1ff Binary files /dev/null and b/inst/examples/container_regression_example.pdf differ diff --git a/inst/examples/container_regression_example.qmd b/inst/examples/container_regression_example.qmd new file mode 100644 index 0000000..bebdb00 --- /dev/null +++ b/inst/examples/container_regression_example.qmd @@ -0,0 +1,154 @@ +--- +title: "container regression example" +--- + +This is an example regression analysis to show how the container package might work. + +We'll use the [food delivery data](https://aml4td.org/chapters/whole-game.html) and start with a three-way split: + +```{r} +#| label: ssshhh +#| include: false + +library(tidymodels) +library(bonsai) +library(container) +library(probably) +library(patchwork) +``` +```{r} +#| label: startup +library(tidymodels) +library(bonsai) # also requires lightgbm package +library(container) # pak::pak(c("tidymodels/container@max"), ask = FALSE) +library(probably) +library(patchwork) + +# ------------------------------------------------------------------------------ + +tidymodels_prefer() +theme_set(theme_bw()) +options(pillar.advice = FALSE, pillar.min_title_chars = Inf) + +# ------------------------------------------------------------------------------ + +data(deliveries, package = "modeldata") + +set.seed(991) +delivery_split <- initial_validation_split(deliveries, prop = c(0.6, 0.2), + strata = time_to_delivery) +delivery_train <- training(delivery_split) +delivery_test <- testing(delivery_split) +delivery_val <- validation(delivery_split) +``` + +Let's deliberately fit a regression model that has poor predicted values: a boosted tree with only three ensemble members: + +```{r} +#| label: bad-boost + +bst_fit <- + boost_tree(trees = 3) %>% + set_engine("lightgbm") %>% + set_mode("regression") %>% + fit(time_to_delivery ~ ., data = delivery_train) +``` + +We predict the validation set and see how bad things are: + +```{r} +#| label: bad-pred + +reg_metrics <- metric_set(rmse, rsq) + +bst_val_pred <- augment(bst_fit, delivery_val) +reg_metrics(bst_val_pred, truth = time_to_delivery, estimate = .pred) +``` + +That R2 looks _great_! 
How well is it calibrated? + +```{r} +#| label: bad-pred-plot +cal_plot_regression(bst_val_pred, truth = time_to_delivery, estimate = .pred) +``` + + +Ooof. One of the calibration tools for the probably package might help this. Let's use a linear regression with spline terms to fix it. First, we'll resample the calibration model to see if it helps: + +```{r} +#| label: cal-resample + +set.seed(10) +bst_val_pred %>% + vfold_cv() %>% + cal_validate_linear(truth = time_to_delivery, estimate = .pred, + smooth = TRUE, metrics = reg_metrics) %>% + collect_metrics() +``` + +That seems promising. Let's fit it to the validation set predictions: + +```{r} +#| label: cal-obj + +bst_cal <- cal_estimate_linear(bst_val_pred, truth = time_to_delivery, + estimate = .pred, smooth = TRUE) +``` + +We could manually use `cal_apply()` to adjust predictions, but instead, we'll add it to the post-processing object: + +```{r} +#| label: post-1 + +post_obj <- + container(mode = "regression") %>% + adjust_numeric_calibration(bst_cal) +post_obj +``` + +Let's add another post-processor to limit the range of predictions (just as a demonstration): + +```{r} +#| label: post-2 + +post_obj <- + post_obj %>% + adjust_numeric_range(lower_limit = 0, upper_limit = 50) +post_obj +``` + +We have to fit the post-processor to use it. However, there are no estimation steps in this instance since everything is either pre-trained (e.g., the calibrator) or user-defined (e.g., the limits). We'll run `fit()` anyway, then apply it to the test results: + +```{r} +#| label: test-pred + +post_res <- + post_obj %>% + fit(bst_val_pred, outcome = c(time_to_delivery), estimate = c(.pred)) + +bst_test_pred <- augment(bst_fit, delivery_test) + +# Without: +reg_metrics(bst_test_pred, truth = time_to_delivery, estimate = .pred) + +# With: +bst_test_proc_pred <- + post_res %>% + predict(bst_test_pred) + +bst_test_proc_pred %>% + reg_metrics(truth = time_to_delivery, estimate = .pred) +``` + +Visually: + +```{r} +#| label: test-plot + +before <- cal_plot_regression(bst_test_pred, truth = time_to_delivery, + estimate = .pred) +after <- cal_plot_regression(bst_test_proc_pred, truth = time_to_delivery, + estimate = .pred) + +before + after +``` diff --git a/man/adjust_equivocal_zone.Rd b/man/adjust_equivocal_zone.Rd new file mode 100644 index 0000000..705f41e --- /dev/null +++ b/man/adjust_equivocal_zone.Rd @@ -0,0 +1,38 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/adjust-equivocal-zone.R +\name{adjust_equivocal_zone} +\alias{adjust_equivocal_zone} +\title{Apply an equivocal zone to a binary classification model.} +\usage{ +adjust_equivocal_zone(x, value = 0.1, threshold = 1/2) +} +\arguments{ +\item{x}{A \code{\link[=container]{container()}}.} + +\item{value}{A numeric value (between zero and 1/2) or \code{\link[hardhat:tune]{hardhat::tune()}}. The +value is the size of the buffer around the threshold.} + +\item{threshold}{A numeric value (between zero and one) or \code{\link[hardhat:tune]{hardhat::tune()}}.} +} +\description{ +Apply an equivocal zone to a binary classification model. 
+} +\examples{ +library(dplyr) +library(modeldata) + +post_obj <- + container(mode = "classification") \%>\% + adjust_equivocal_zone(value = 1 / 4) + + +post_res <- fit( + post_obj, + two_class_example, + outcome = c(truth), + estimate = c(predicted), + probabilities = c(Class1, Class2) +) + +predict(post_res, two_class_example) +} diff --git a/man/adjust_numeric_calibration.Rd b/man/adjust_numeric_calibration.Rd new file mode 100644 index 0000000..f8e6315 --- /dev/null +++ b/man/adjust_numeric_calibration.Rd @@ -0,0 +1,41 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/adjust-numeric-calibration.R +\name{adjust_numeric_calibration} +\alias{adjust_numeric_calibration} +\title{Re-calibrate numeric predictions} +\usage{ +adjust_numeric_calibration(x, calibrator) +} +\arguments{ +\item{x}{A \code{\link[=container]{container()}}.} + +\item{calibrator}{A pre-trained calibration method from the \pkg{probably} +package, such as \code{\link[probably:cal_estimate_linear]{probably::cal_estimate_linear()}}.} +} +\description{ +Re-calibrate numeric predictions +} +\examples{ +library(modeldata) +library(probably) +library(tibble) + +# create example data +set.seed(1) +dat <- tibble(y = rnorm(100), y_pred = y/2 + rnorm(100)) + +dat + +# calibrate numeric predictions +reg_cal <- cal_estimate_linear(dat, truth = y, estimate = y_pred) + +# specify calibration +reg_ctr <- + container(mode = "regression") \%>\% + adjust_numeric_calibration(reg_cal) + +# "train" container +reg_ctr_trained <- fit(reg_ctr, dat, outcome = y, estimate = y_pred) + +predict(reg_ctr, dat) +} diff --git a/man/adjust_numeric_range.Rd b/man/adjust_numeric_range.Rd new file mode 100644 index 0000000..08911f8 --- /dev/null +++ b/man/adjust_numeric_range.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/adjust-numeric-range.R +\name{adjust_numeric_range} +\alias{adjust_numeric_range} +\title{Truncate the range of numeric predictions} +\usage{ +adjust_numeric_range(x, lower_limit = -Inf, upper_limit = Inf) +} +\arguments{ +\item{x}{A \code{\link[=container]{container()}}.} + +\item{upper_limit, lower_limit}{A numeric value, NA (for no truncation) or +\code{\link[hardhat:tune]{hardhat::tune()}}.} +} +\description{ +Truncate the range of numeric predictions +} diff --git a/man/adjust_predictions_custom.Rd b/man/adjust_predictions_custom.Rd new file mode 100644 index 0000000..4f54ced --- /dev/null +++ b/man/adjust_predictions_custom.Rd @@ -0,0 +1,39 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/adjust-predictions-custom.R +\name{adjust_predictions_custom} +\alias{adjust_predictions_custom} +\title{Change or add variables} +\usage{ +adjust_predictions_custom(x, ..., .pkgs = character(0)) +} +\arguments{ +\item{x}{A \code{\link[=container]{container()}}.} + +\item{...}{Name-value pairs of expressions. 
See \code{\link[dplyr:mutate]{dplyr::mutate()}}.} + +\item{.pkgs}{A character string of extra packages that are needed to execute +the commands.} +} +\description{ +Change or add variables +} +\examples{ +library(dplyr) +library(modeldata) + +post_obj <- + container(mode = "classification") \%>\% + adjust_equivocal_zone() \%>\% + adjust_predictions_custom(linear_predictor = binomial()$linkfun(Class2)) + + +post_res <- fit( + post_obj, + two_class_example, + outcome = c(truth), + estimate = c(predicted), + probabilities = c(Class1, Class2) +) + +predict(post_res, two_class_example) +} diff --git a/man/adjust_probability_calibration.Rd b/man/adjust_probability_calibration.Rd new file mode 100644 index 0000000..65e0392 --- /dev/null +++ b/man/adjust_probability_calibration.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/adjust-probability-calibration.R +\name{adjust_probability_calibration} +\alias{adjust_probability_calibration} +\title{Re-calibrate classification probability predictions} +\usage{ +adjust_probability_calibration(x, calibrator) +} +\arguments{ +\item{x}{A \code{\link[=container]{container()}}.} + +\item{calibrator}{A pre-trained calibration method from the \pkg{probably} +package, such as \code{\link[probably:cal_estimate_logistic]{probably::cal_estimate_logistic()}}.} +} +\description{ +Re-calibrate classification probability predictions +} diff --git a/man/adjust_probability_threshold.Rd b/man/adjust_probability_threshold.Rd new file mode 100644 index 0000000..ea227dd --- /dev/null +++ b/man/adjust_probability_threshold.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/adjust-probability-threshold.R +\name{adjust_probability_threshold} +\alias{adjust_probability_threshold} +\title{Change the event threshold} +\usage{ +adjust_probability_threshold(x, threshold = 0.5) +} +\arguments{ +\item{x}{A \code{\link[=container]{container()}}.} + +\item{threshold}{A numeric value (between zero and one) or \code{\link[hardhat:tune]{hardhat::tune()}}.} +} +\description{ +Change the event threshold +} +\examples{ +library(dplyr) +library(modeldata) + +post_obj <- + container(mode = "classification") \%>\% + adjust_probability_threshold(threshold = .1) + +two_class_example \%>\% count(predicted) + +post_res <- fit( + post_obj, + two_class_example, + outcome = c(truth), + estimate = c(predicted), + probabilities = c(Class1, Class2) +) + +predict(post_res, two_class_example) \%>\% count(predicted) +} diff --git a/man/container-package.Rd b/man/container-package.Rd index 69fa1b8..8f314aa 100644 --- a/man/container-package.Rd +++ b/man/container-package.Rd @@ -2,7 +2,6 @@ % Please edit documentation in R/container-package.R \docType{package} \name{container-package} -\alias{container} \alias{container-package} \title{container: Sandbox for a postprocessor object} \description{ diff --git a/man/container.Rd b/man/container.Rd new file mode 100644 index 0000000..1d074b5 --- /dev/null +++ b/man/container.Rd @@ -0,0 +1,41 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/container.R +\name{container} +\alias{container} +\title{Declare post-processing for model predictions} +\usage{ +container( + mode, + type = "unknown", + outcome = NULL, + estimate = NULL, + probabilities = NULL, + time = NULL +) +} +\arguments{ +\item{mode}{The model's mode, one of \code{"classification"}, or \code{"regression"}. 
+Modes of \code{"censored regression"} are not currently supported.} + +\item{type}{The model sub-type. Possible values are \code{"unknown"}, \code{"regression"}, +\code{"binary"}, or \code{"multiclass"}.} + +\item{outcome}{The name of the outcome variable.} + +\item{estimate}{The name of the point estimate (e.g. predicted class). In +tidymodels, this corresponds to column names \code{.pred}, \code{.pred_class}, or +\code{.pred_time}.} + +\item{probabilities}{The names of class probability estimates (if any). For +classification, these should be given in the order of the factor levels of +the \code{estimate}.} + +\item{time}{The name of the predicted event time. (not yet supported)} +} +\description{ +Declare post-processing for model predictions +} +\examples{ + +container(mode = "regression") +} diff --git a/man/reexports.Rd b/man/reexports.Rd new file mode 100644 index 0000000..43758ef --- /dev/null +++ b/man/reexports.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/reexport.R +\docType{import} +\name{reexports} +\alias{reexports} +\alias{fit} +\alias{tidy} +\alias{required_pkgs} +\alias{tunable} +\alias{tune_args} +\alias{extract_parameter_set_dials} +\alias{extract_parameter_dials} +\alias{\%>\%} +\title{Objects exported from other packages} +\keyword{internal} +\description{ +These objects are imported from other packages. Follow the links +below to see their documentation. + +\describe{ + \item{dplyr}{\code{\link[dplyr:reexports]{\%>\%}}} + + \item{generics}{\code{\link[generics]{fit}}, \code{\link[generics]{required_pkgs}}, \code{\link[generics]{tidy}}, \code{\link[generics]{tunable}}, \code{\link[generics]{tune_args}}} + + \item{hardhat}{\code{\link[hardhat:hardhat-extract]{extract_parameter_dials}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_set_dials}}} +}} + diff --git a/tests/testthat/_snaps/adjust-equivocal-zone.md b/tests/testthat/_snaps/adjust-equivocal-zone.md new file mode 100644 index 0000000..4021efc --- /dev/null +++ b/tests/testthat/_snaps/adjust-equivocal-zone.md @@ -0,0 +1,22 @@ +# adjustment printing + + Code + ctr_cls %>% adjust_equivocal_zone() + Message + + -- Container ------------------------------------------------------------------- + A postprocessor with 1 operation: + + * Add equivocal zone of size 0.1. + +--- + + Code + ctr_cls %>% adjust_equivocal_zone(hardhat::tune()) + Message + + -- Container ------------------------------------------------------------------- + A postprocessor with 1 operation: + + * Add equivocal zone of optimized size. + diff --git a/tests/testthat/_snaps/adjust-numeric-calibration.md b/tests/testthat/_snaps/adjust-numeric-calibration.md new file mode 100644 index 0000000..dd4d98c --- /dev/null +++ b/tests/testthat/_snaps/adjust-numeric-calibration.md @@ -0,0 +1,35 @@ +# adjustment printing + + Code + ctr_reg %>% adjust_numeric_calibration(dummy_reg_cal) + Message + + -- Container ------------------------------------------------------------------- + A postprocessor with 1 operation: + + * Re-calibrate numeric predictions. + +# errors informatively with bad input + + Code + adjust_numeric_calibration(ctr_reg) + Condition + Error in `adjust_numeric_calibration()`: + ! `calibrator` is absent but must be supplied. + +--- + + Code + adjust_numeric_calibration(ctr_reg, "boop") + Condition + Error in `adjust_numeric_calibration()`: + ! `calibrator` should be a object (`?probably::cal_estimate_linear()`), not a string. 
+ +--- + + Code + adjust_numeric_calibration(ctr_cls, dummy_cls_cal) + Condition + Error in `adjust_numeric_calibration()`: + ! `calibrator` should be a object (`?probably::cal_estimate_linear()`), not a object. + diff --git a/tests/testthat/_snaps/adjust-numeric-range.md b/tests/testthat/_snaps/adjust-numeric-range.md new file mode 100644 index 0000000..afa9537 --- /dev/null +++ b/tests/testthat/_snaps/adjust-numeric-range.md @@ -0,0 +1,44 @@ +# adjustment printing + + Code + ctr_reg %>% adjust_numeric_range() + Message + + -- Container ------------------------------------------------------------------- + A postprocessor with 1 operation: + + * Constrain numeric predictions to be between [-Inf, Inf]. + +--- + + Code + ctr_reg %>% adjust_numeric_range(hardhat::tune()) + Message + + -- Container ------------------------------------------------------------------- + A postprocessor with 1 operation: + + * Constrain numeric predictions to be between [?, Inf]. + +--- + + Code + ctr_reg %>% adjust_numeric_range(-1, hardhat::tune()) + Message + + -- Container ------------------------------------------------------------------- + A postprocessor with 1 operation: + + * Constrain numeric predictions to be between [-1, ?]. + +--- + + Code + ctr_reg %>% adjust_numeric_range(hardhat::tune(), 1) + Message + + -- Container ------------------------------------------------------------------- + A postprocessor with 1 operation: + + * Constrain numeric predictions to be between [?, 1]. + diff --git a/tests/testthat/_snaps/adjust-predictions-custom.md b/tests/testthat/_snaps/adjust-predictions-custom.md new file mode 100644 index 0000000..71e9e76 --- /dev/null +++ b/tests/testthat/_snaps/adjust-predictions-custom.md @@ -0,0 +1,11 @@ +# adjustment printing + + Code + ctr_cls %>% adjust_predictions_custom() + Message + + -- Container ------------------------------------------------------------------- + A postprocessor with 1 operation: + + * Adjust predictions using custom code. + diff --git a/tests/testthat/_snaps/adjust-probability-calibration.md b/tests/testthat/_snaps/adjust-probability-calibration.md new file mode 100644 index 0000000..52a037e --- /dev/null +++ b/tests/testthat/_snaps/adjust-probability-calibration.md @@ -0,0 +1,35 @@ +# adjustment printing + + Code + ctr_cls %>% adjust_probability_calibration(dummy_cls_cal) + Message + + -- Container ------------------------------------------------------------------- + A postprocessor with 1 operation: + + * Re-calibrate classification probabilities. + +# errors informatively with bad input + + Code + adjust_probability_calibration(ctr_cls) + Condition + Error in `adjust_probability_calibration()`: + ! `calibrator` is absent but must be supplied. + +--- + + Code + adjust_probability_calibration(ctr_cls, "boop") + Condition + Error in `adjust_probability_calibration()`: + ! `calibrator` should be a or object (`?probably::cal_estimate_logistic()`), not a string. + +--- + + Code + adjust_probability_calibration(ctr_cls, dummy_reg_cal) + Condition + Error in `adjust_probability_calibration()`: + ! `calibrator` should be a or object (`?probably::cal_estimate_logistic()`), not a object. 
+ diff --git a/tests/testthat/_snaps/adjust-probability-threshold.md b/tests/testthat/_snaps/adjust-probability-threshold.md new file mode 100644 index 0000000..2affcca --- /dev/null +++ b/tests/testthat/_snaps/adjust-probability-threshold.md @@ -0,0 +1,22 @@ +# adjustment printing + + Code + ctr_cls %>% adjust_probability_threshold() + Message + + -- Container ------------------------------------------------------------------- + A postprocessor with 1 operation: + + * Adjust probability threshold to 0.5. + +--- + + Code + ctr_cls %>% adjust_probability_threshold(hardhat::tune()) + Message + + -- Container ------------------------------------------------------------------- + A postprocessor with 1 operation: + + * Adjust probability threshold to optimized value. + diff --git a/tests/testthat/_snaps/container.md b/tests/testthat/_snaps/container.md new file mode 100644 index 0000000..24aa650 --- /dev/null +++ b/tests/testthat/_snaps/container.md @@ -0,0 +1,43 @@ +# container printing + + Code + ctr_cls + Message + + -- Container ------------------------------------------------------------------- + A postprocessor with 0 operations. + +--- + + Code + container(mode = "classification", type = "binary") + Message + + -- Container ------------------------------------------------------------------- + A binary postprocessor with 0 operations. + +--- + + Code + container(mode = "classification", type = "binary") %>% + adjust_probability_threshold(0.2) + Message + + -- Container ------------------------------------------------------------------- + A binary postprocessor with 1 operation: + + * Adjust probability threshold to 0.2. + +--- + + Code + container(mode = "classification", type = "binary") %>% + adjust_probability_threshold(0.2) %>% adjust_equivocal_zone() + Message + + -- Container ------------------------------------------------------------------- + A binary postprocessor with 2 operations: + + * Adjust probability threshold to 0.2. + * Add equivocal zone of size 0.1. + diff --git a/tests/testthat/_snaps/utils.md b/tests/testthat/_snaps/utils.md new file mode 100644 index 0000000..8e52628 --- /dev/null +++ b/tests/testthat/_snaps/utils.md @@ -0,0 +1,8 @@ +# check_container raises informative error + + Code + adjust_probability_threshold("boop") + Condition + Error in `adjust_probability_threshold()`: + ! `x` should be a (`?container::container()`), not a string. + diff --git a/tests/testthat/_snaps/validation-rules.md b/tests/testthat/_snaps/validation-rules.md new file mode 100644 index 0000000..a4ae358 --- /dev/null +++ b/tests/testthat/_snaps/validation-rules.md @@ -0,0 +1,49 @@ +# validation of operations (regression) + + Code + container(mode = "regression") %>% adjust_numeric_range(lower_limit = 2) %>% + adjust_numeric_calibration(dummy_reg_cal) %>% adjust_predictions_custom( + squared = .pred^2) + Condition + Error in `adjust_numeric_calibration()`: + ! Calibration should come before other operations. + +# validation of operations (classification) + + Code + container(mode = "classification") %>% adjust_probability_threshold(threshold = 0.4) %>% + adjust_probability_calibration(dummy_cls_cal) + Condition + Error in `adjust_probability_calibration()`: + ! Operations that change the hard class predictions must come after operations that update the class probability estimates. 
+
+---
+
+    Code
+      container(mode = "classification") %>% adjust_predictions_custom(veg = "potato") %>%
+        adjust_probability_threshold(threshold = 0.4) %>%
+        adjust_probability_calibration(dummy_cls_cal)
+    Condition
+      Error in `adjust_probability_calibration()`:
+      ! Operations that change the hard class predictions must come after operations that update the class probability estimates.
+
+---
+
+    Code
+      container(mode = "classification") %>% adjust_predictions_custom(veg = "potato") %>%
+        adjust_probability_threshold(threshold = 0.4) %>%
+        adjust_probability_threshold(threshold = 0.5) %>%
+        adjust_probability_calibration(dummy_cls_cal)
+    Condition
+      Error in `adjust_probability_threshold()`:
+      ! Operations cannot be duplicated: "probability_threshold"
+
+---
+
+    Code
+      container(mode = "classification") %>% adjust_equivocal_zone(value = 0.2) %>%
+        adjust_probability_threshold(threshold = 0.4)
+    Condition
+      Error in `adjust_probability_threshold()`:
+      ! Equivocal zone addition should come after operations that update the class probability estimates or hard class predictions.
+
diff --git a/tests/testthat/helper-objects.R b/tests/testthat/helper-objects.R
new file mode 100644
index 0000000..1299a28
--- /dev/null
+++ b/tests/testthat/helper-objects.R
@@ -0,0 +1,2 @@
+ctr_cls <- container(mode = "classification")
+ctr_reg <- container(mode = "regression")
diff --git a/tests/testthat/test-adjust-equivocal-zone.R b/tests/testthat/test-adjust-equivocal-zone.R
new file mode 100644
index 0000000..ab0793e
--- /dev/null
+++ b/tests/testthat/test-adjust-equivocal-zone.R
@@ -0,0 +1,4 @@
+test_that("adjustment printing", {
+  expect_snapshot(ctr_cls %>% adjust_equivocal_zone())
+  expect_snapshot(ctr_cls %>% adjust_equivocal_zone(hardhat::tune()))
+})
diff --git a/tests/testthat/test-adjust-numeric-calibration.R b/tests/testthat/test-adjust-numeric-calibration.R
new file mode 100644
index 0000000..b67b717
--- /dev/null
+++ b/tests/testthat/test-adjust-numeric-calibration.R
@@ -0,0 +1,13 @@
+test_that("adjustment printing", {
+  dummy_reg_cal <- structure(list(), class = "cal_regression")
+  expect_snapshot(ctr_reg %>% adjust_numeric_calibration(dummy_reg_cal))
+})
+
+test_that("errors informatively with bad input", {
+  # check for `adjust_numeric_calibration(container)` is in `utils.R` tests
+
+  expect_snapshot(error = TRUE, adjust_numeric_calibration(ctr_reg))
+  expect_snapshot(error = TRUE, adjust_numeric_calibration(ctr_reg, "boop"))
+  dummy_cls_cal <- structure(list(), class = "cal_binary")
+  expect_snapshot(error = TRUE, adjust_numeric_calibration(ctr_cls, dummy_cls_cal))
+})
diff --git a/tests/testthat/test-adjust-numeric-range.R b/tests/testthat/test-adjust-numeric-range.R
new file mode 100644
index 0000000..6e04c75
--- /dev/null
+++ b/tests/testthat/test-adjust-numeric-range.R
@@ -0,0 +1,7 @@
+test_that("adjustment printing", {
+  expect_snapshot(ctr_reg %>% adjust_numeric_range())
+  expect_snapshot(ctr_reg %>% adjust_numeric_range(hardhat::tune()))
+  expect_snapshot(ctr_reg %>% adjust_numeric_range(-1, hardhat::tune()))
+  expect_snapshot(ctr_reg %>% adjust_numeric_range(hardhat::tune(), 1))
+})
+
diff --git a/tests/testthat/test-adjust-predictions-custom.R b/tests/testthat/test-adjust-predictions-custom.R
new file mode 100644
index 0000000..70c6359
--- /dev/null
+++ b/tests/testthat/test-adjust-predictions-custom.R
@@ -0,0 +1,3 @@
+test_that("adjustment printing", {
+  expect_snapshot(ctr_cls %>% adjust_predictions_custom())
+})
diff --git a/tests/testthat/test-adjust-probability-calibration.R
b/tests/testthat/test-adjust-probability-calibration.R new file mode 100644 index 0000000..193ef1d --- /dev/null +++ b/tests/testthat/test-adjust-probability-calibration.R @@ -0,0 +1,16 @@ +test_that("adjustment printing", { + dummy_cls_cal <- structure(list(), class = "cal_binary") + expect_snapshot(ctr_cls %>% adjust_probability_calibration(dummy_cls_cal)) +}) + +test_that("errors informatively with bad input", { + # check for `adjust_probably_calibration(container)` is in `utils.R` tests + + expect_snapshot(error = TRUE, adjust_probability_calibration(ctr_cls)) + expect_snapshot(error = TRUE, adjust_probability_calibration(ctr_cls, "boop")) + dummy_reg_cal <- structure(list(), class = "cal_regression") + expect_snapshot( + error = TRUE, + adjust_probability_calibration(ctr_cls, dummy_reg_cal) + ) +}) diff --git a/tests/testthat/test-adjust-probability-threshold.R b/tests/testthat/test-adjust-probability-threshold.R new file mode 100644 index 0000000..38ae223 --- /dev/null +++ b/tests/testthat/test-adjust-probability-threshold.R @@ -0,0 +1,4 @@ +test_that("adjustment printing", { + expect_snapshot(ctr_cls %>% adjust_probability_threshold()) + expect_snapshot(ctr_cls %>% adjust_probability_threshold(hardhat::tune())) +}) diff --git a/tests/testthat/test-container.R b/tests/testthat/test-container.R new file mode 100644 index 0000000..35837a5 --- /dev/null +++ b/tests/testthat/test-container.R @@ -0,0 +1,13 @@ +test_that("container printing", { + expect_snapshot(ctr_cls) + expect_snapshot(container(mode = "classification", type = "binary")) + expect_snapshot( + container(mode = "classification", type = "binary") %>% + adjust_probability_threshold(.2) + ) + expect_snapshot( + container(mode = "classification", type = "binary") %>% + adjust_probability_threshold(.2) %>% + adjust_equivocal_zone() + ) +}) diff --git a/tests/testthat/test-placeholder.R b/tests/testthat/test-placeholder.R deleted file mode 100644 index 8849056..0000000 --- a/tests/testthat/test-placeholder.R +++ /dev/null @@ -1,3 +0,0 @@ -test_that("multiplication works", { - expect_equal(2 * 2, 4) -}) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R new file mode 100644 index 0000000..41b093a --- /dev/null +++ b/tests/testthat/test-utils.R @@ -0,0 +1,4 @@ +test_that("check_container raises informative error", { + expect_snapshot(error = TRUE, adjust_probability_threshold("boop")) + expect_no_condition(ctr_cls %>% adjust_probability_threshold(.5)) +}) diff --git a/tests/testthat/test-validation-rules.R b/tests/testthat/test-validation-rules.R new file mode 100644 index 0000000..c86cb1b --- /dev/null +++ b/tests/testthat/test-validation-rules.R @@ -0,0 +1,82 @@ +test_that("validation of operations (regression)", { + dummy_reg_cal <- list() + class(dummy_reg_cal) <- "cal_regression" + + expect_silent( + reg_ctr <- + container(mode = "regression") %>% + adjust_numeric_calibration(dummy_reg_cal) %>% + adjust_numeric_range(lower_limit = 2) %>% + adjust_predictions_custom(squared = .pred^2) + ) + + expect_snapshot( + container(mode = "regression") %>% + adjust_numeric_range(lower_limit = 2) %>% + adjust_numeric_calibration(dummy_reg_cal) %>% + adjust_predictions_custom(squared = .pred^2), + error = TRUE + ) + + # todo should we error if a mutate occurs beforehand? Can we detect if it + # modifies the prediction? 
+ expect_silent( + reg_ctr <- + container(mode = "regression") %>% + adjust_predictions_custom(squared = .pred^2) %>% + adjust_numeric_calibration(dummy_reg_cal) %>% + adjust_numeric_range(lower_limit = 2) + ) +}) + +test_that("validation of operations (classification)", { + dummy_cls_cal <- list() + class(dummy_cls_cal) <- "cal_binary" + + expect_silent( + cls_ctr_1 <- + container(mode = "classification") %>% + adjust_probability_calibration(dummy_cls_cal) %>% + adjust_probability_threshold(threshold = .4) + ) + + expect_silent( + cls_ctr_2 <- + container(mode = "classification") %>% + adjust_predictions_custom(starch = "potato") %>% + adjust_predictions_custom(veg = "green beans") %>% + adjust_probability_calibration(dummy_cls_cal) %>% + adjust_probability_threshold(threshold = .4) + ) + + expect_snapshot( + container(mode = "classification") %>% + adjust_probability_threshold(threshold = .4) %>% + adjust_probability_calibration(dummy_cls_cal), + error = TRUE + ) + + expect_snapshot( + container(mode = "classification") %>% + adjust_predictions_custom(veg = "potato") %>% + adjust_probability_threshold(threshold = .4) %>% + adjust_probability_calibration(dummy_cls_cal), + error = TRUE + ) + + expect_snapshot( + container(mode = "classification") %>% + adjust_predictions_custom(veg = "potato") %>% + adjust_probability_threshold(threshold = .4) %>% + adjust_probability_threshold(threshold = .5) %>% + adjust_probability_calibration(dummy_cls_cal), + error = TRUE + ) + + expect_snapshot( + container(mode = "classification") %>% + adjust_equivocal_zone(value = .2) %>% + adjust_probability_threshold(threshold = .4), + error = TRUE + ) +})
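
To complement the regression walkthrough in `inst/examples/container_regression_example.qmd`, here is a minimal classification sketch of how the adjustments above compose. It is not part of the diff and is untested: it assumes the development versions of container and probably are installed, that `probably::cal_estimate_logistic()` accepts the `Class1`/`Class2` probability columns of `modeldata::two_class_example` via its `estimate` argument, and the threshold and equivocal-zone sizes are arbitrary illustrative values. The operation order (calibration, then threshold, then equivocal zone) is chosen to satisfy the rules in `R/validation-rules.R`.

```r
library(container)
library(probably)
library(modeldata)

data(two_class_example)

# pre-train a probability calibrator on held-out predictions
cls_cal <- cal_estimate_logistic(
  two_class_example,
  truth = truth,
  estimate = c(Class1, Class2)
)

# declare the post-processing: calibrate, re-threshold, then add an equivocal zone
post_obj <-
  container(mode = "classification") %>%
  adjust_probability_calibration(cls_cal) %>%
  adjust_probability_threshold(threshold = 0.4) %>%
  adjust_equivocal_zone(value = 0.05)

# fit() records the outcome/estimate/probability column mapping; nothing else is
# estimated here because the calibrator is pre-trained and the other values are
# user-supplied
post_res <- fit(
  post_obj,
  two_class_example,
  outcome = c(truth),
  estimate = c(predicted),
  probabilities = c(Class1, Class2)
)

predict(post_res, two_class_example)
```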