Skip to content

Commit

Permalink
inherit EQ threshold from adjust_probability_threshold() (closes #6)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Dec 11, 2024
1 parent b6ddea4 commit 996fa03
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 6 deletions.
29 changes: 25 additions & 4 deletions R/adjust-equivocal-zone.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#' @param value A numeric value (between zero and 1/2) or [hardhat::tune()]. The
#' value is the size of the buffer around the threshold.
#' @param threshold A numeric value (between zero and one) or [hardhat::tune()].
#' Defaults to `adjust_probability_threshold(threshold)` if previously set
#' in `x`, or `1 / 2` if not.
#'
#' @section Data Usage:
#' This adjustment doesn't require estimation and, as such, the same data that's
Expand Down Expand Up @@ -56,16 +58,14 @@
#' # adjust hard class predictions
#' predict(tlr_fit, two_class_example) %>% count(predicted)
#' @export
adjust_equivocal_zone <- function(x, value = 0.1, threshold = 1 / 2) {
adjust_equivocal_zone <- function(x, value = 0.1, threshold = NULL) {
validate_probably_available()

check_tailor(x)
threshold <- infer_threshold(x = x, threshold = threshold)
if (!is_tune(value)) {
check_number_decimal(value, min = 0, max = 1 / 2)
}
if (!is_tune(threshold)) {
check_number_decimal(threshold, min = 10^-10, max = 1 - 10^-10)
}

adj <-
new_adjustment(
Expand Down Expand Up @@ -151,3 +151,24 @@ tunable.equivocal_zone <- function(x, ...) {
component_id = "equivocal_zone"
)
}

infer_threshold <- function(x, threshold, call = caller_env()) {
if (!is.null(threshold) && !is_tune(threshold)) {
check_number_decimal(threshold, min = 10^-10, max = 1 - 10^-10, call = call)
return(threshold)
}

if (is_tune(threshold)) {
return(threshold)
}

# use `map() %>% unlist()` rather than `map_dbl` to handle NULLs
thresholds <- purrr::map(x$adjustments, purrr::pluck, "arguments", "threshold")
thresholds <- unlist(thresholds)

if (!is.null(thresholds)) {
return(thresholds[length(thresholds)])
}

1 / 2
}
6 changes: 4 additions & 2 deletions man/adjust_equivocal_zone.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 34 additions & 0 deletions tests/testthat/test-adjust-equivocal-zone.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,37 @@ test_that("tunable", {
c("name", "call_info", "source", "component", "component_id")
)
})

test_that("adjust_equivocal_zone inherits previously set threshold", {
# previously set
tlr <-
tailor() %>%
adjust_probability_threshold(threshold = .4) %>%
adjust_equivocal_zone(value = .2)

expect_equal(tlr$adjustments[[2]]$arguments$threshold, .4)

# not previously set, defualts to 1 / 2
tlr <-
tailor() %>%
adjust_equivocal_zone(value = .2)

expect_equal(tlr$adjustments[[1]]$arguments$threshold, .5)

# previously set, among other things
tlr <-
tailor() %>%
adjust_predictions_custom(.pred = identity(.pred)) %>%
adjust_probability_threshold(threshold = .4) %>%
adjust_equivocal_zone(value = .2)

expect_equal(tlr$adjustments[[3]]$arguments$threshold, .4)

# not previously set, but other stuff happened
tlr <-
tailor() %>%
adjust_predictions_custom(.pred = identity(.pred)) %>%
adjust_equivocal_zone(value = .2)

expect_equal(tlr$adjustments[[2]]$arguments$threshold, .5)
})

0 comments on commit 996fa03

Please sign in to comment.