From 996fa0387a93a22b49dcb1e02f776903f2daf41c Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Wed, 11 Dec 2024 09:55:34 -0600 Subject: [PATCH] inherit EQ threshold from `adjust_probability_threshold()` (closes #6) --- R/adjust-equivocal-zone.R | 29 +++++++++++++++--- man/adjust_equivocal_zone.Rd | 6 ++-- tests/testthat/test-adjust-equivocal-zone.R | 34 +++++++++++++++++++++ 3 files changed, 63 insertions(+), 6 deletions(-) diff --git a/R/adjust-equivocal-zone.R b/R/adjust-equivocal-zone.R index 829fade..5286e91 100644 --- a/R/adjust-equivocal-zone.R +++ b/R/adjust-equivocal-zone.R @@ -9,6 +9,8 @@ #' @param value A numeric value (between zero and 1/2) or [hardhat::tune()]. The #' value is the size of the buffer around the threshold. #' @param threshold A numeric value (between zero and one) or [hardhat::tune()]. +#' Defaults to `adjust_probability_threshold(threshold)` if previously set +#' in `x`, or `1 / 2` if not. #' #' @section Data Usage: #' This adjustment doesn't require estimation and, as such, the same data that's @@ -56,16 +58,14 @@ #' # adjust hard class predictions #' predict(tlr_fit, two_class_example) %>% count(predicted) #' @export -adjust_equivocal_zone <- function(x, value = 0.1, threshold = 1 / 2) { +adjust_equivocal_zone <- function(x, value = 0.1, threshold = NULL) { validate_probably_available() check_tailor(x) + threshold <- infer_threshold(x = x, threshold = threshold) if (!is_tune(value)) { check_number_decimal(value, min = 0, max = 1 / 2) } - if (!is_tune(threshold)) { - check_number_decimal(threshold, min = 10^-10, max = 1 - 10^-10) - } adj <- new_adjustment( @@ -151,3 +151,24 @@ tunable.equivocal_zone <- function(x, ...) 
{ component_id = "equivocal_zone" ) } + +infer_threshold <- function(x, threshold, call = caller_env()) { + if (!is.null(threshold) && !is_tune(threshold)) { + check_number_decimal(threshold, min = 10^-10, max = 1 - 10^-10, call = call) + return(threshold) + } + + if (is_tune(threshold)) { + return(threshold) + } + + # use `map() %>% unlist()` rather than `map_dbl` to handle NULLs + thresholds <- purrr::map(x$adjustments, purrr::pluck, "arguments", "threshold") + thresholds <- unlist(thresholds) + + if (!is.null(thresholds)) { + return(thresholds[length(thresholds)]) + } + + 1 / 2 +} diff --git a/man/adjust_equivocal_zone.Rd b/man/adjust_equivocal_zone.Rd index 42f5476..abb554b 100644 --- a/man/adjust_equivocal_zone.Rd +++ b/man/adjust_equivocal_zone.Rd @@ -4,7 +4,7 @@ \alias{adjust_equivocal_zone} \title{Apply an equivocal zone to a binary classification model.} \usage{ -adjust_equivocal_zone(x, value = 0.1, threshold = 1/2) +adjust_equivocal_zone(x, value = 0.1, threshold = NULL) } \arguments{ \item{x}{A \code{\link[=tailor]{tailor()}}.} @@ -12,7 +12,9 @@ adjust_equivocal_zone(x, value = 0.1, threshold = 1/2) \item{value}{A numeric value (between zero and 1/2) or \code{\link[hardhat:tune]{hardhat::tune()}}. The value is the size of the buffer around the threshold.} -\item{threshold}{A numeric value (between zero and one) or \code{\link[hardhat:tune]{hardhat::tune()}}.} +\item{threshold}{A numeric value (between zero and one) or \code{\link[hardhat:tune]{hardhat::tune()}}. 
+Defaults to \code{adjust_probability_threshold(threshold)} if previously set +in \code{x}, or \code{1 / 2} if not.} } \description{ Equivocal zones describe intervals of predicted probabilities that are deemed diff --git a/tests/testthat/test-adjust-equivocal-zone.R b/tests/testthat/test-adjust-equivocal-zone.R index 5bae335..e2d0835 100644 --- a/tests/testthat/test-adjust-equivocal-zone.R +++ b/tests/testthat/test-adjust-equivocal-zone.R @@ -72,3 +72,37 @@ test_that("tunable", { c("name", "call_info", "source", "component", "component_id") ) }) + +test_that("adjust_equivocal_zone inherits previously set threshold", { + # previously set + tlr <- + tailor() %>% + adjust_probability_threshold(threshold = .4) %>% + adjust_equivocal_zone(value = .2) + + expect_equal(tlr$adjustments[[2]]$arguments$threshold, .4) + + # not previously set, defaults to 1 / 2 + tlr <- + tailor() %>% + adjust_equivocal_zone(value = .2) + + expect_equal(tlr$adjustments[[1]]$arguments$threshold, .5) + + # previously set, among other things + tlr <- + tailor() %>% + adjust_predictions_custom(.pred = identity(.pred)) %>% + adjust_probability_threshold(threshold = .4) %>% + adjust_equivocal_zone(value = .2) + + expect_equal(tlr$adjustments[[3]]$arguments$threshold, .4) + + # not previously set, but other stuff happened + tlr <- + tailor() %>% + adjust_predictions_custom(.pred = identity(.pred)) %>% + adjust_equivocal_zone(value = .2) + + expect_equal(tlr$adjustments[[2]]$arguments$threshold, .5) +})