From 9383acdca9c19aadeb2762d353bb3dbf096c5eb5 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Tue, 17 Sep 2024 13:49:34 -0500 Subject: [PATCH] make `set_tailor_type()` play nice with `infer_type()` (closes #38) --- R/tailor.R | 8 +++-- R/utils.R | 21 +++++++++++++ tests/testthat/_snaps/utils.md | 36 ++++++++++++++++++++++ tests/testthat/test-utils.R | 55 ++++++++++++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 2 deletions(-) diff --git a/R/tailor.R b/R/tailor.R index bb69383..5b67b54 100644 --- a/R/tailor.R +++ b/R/tailor.R @@ -209,8 +209,9 @@ predict.tailor <- function(object, new_data, ...) { new_data } -set_tailor_type <- function(object, y) { +set_tailor_type <- function(object, y, call = caller_env()) { if (object$type != "unknown") { + check_outcome_type(y, object$type, call = call) return(object) } if (is.factor(y)) { @@ -223,7 +224,10 @@ set_tailor_type <- function(object, y) { } else if (is.numeric(y)) { object$type <- "regression" } else { - cli_abort("Only factor and numeric outcomes are currently supported.") + cli_abort( + "Only factor and numeric outcomes are currently supported.", + call = call + ) } object } diff --git a/R/utils.R b/R/utils.R index 483b4a0..4f934c9 100644 --- a/R/utils.R +++ b/R/utils.R @@ -193,6 +193,27 @@ check_method <- function(method, method } +# at `fit()` time, we check the type of the outcome vs the type +# supported by the applied adjustments. where this is called currently, +# we know already that `type` is not "unknown" +check_outcome_type <- function(outcome, type, call) { + outcome_is_compatible <- + switch( + type, + regression = is.numeric(outcome), + binary = , multiclass = is.factor(outcome), + FALSE + ) + + if (!outcome_is_compatible) { + cli_abort( + "Tailors with {type} adjustments are not compatible + with {.cls {class(outcome)}} outcomes.", + call = call + ) + } +} + check_selection <- function(selector, result, arg, call = caller_env()) { if (length(result) == 0) { cli_abort( diff --git a/tests/testthat/_snaps/utils.md b/tests/testthat/_snaps/utils.md index 2f4aaf8..a655d80 100644 --- a/tests/testthat/_snaps/utils.md +++ b/tests/testthat/_snaps/utils.md @@ -6,3 +6,39 @@ Error in `adjust_probability_threshold()`: ! `x` should be a (`?tailor::tailor()`), not a string. +# fit.tailor() errors informatively with incompatible outcome + + Code + fit(tailor() %>% adjust_probability_threshold(0.1), two_class_example, outcome = c( + test_numeric), estimate = c(predicted), probabilities = c(Class1, Class2)) + Condition + Error in `fit()`: + ! Tailors with binary adjustments are not compatible with outcomes. + +--- + + Code + fit(tailor() %>% adjust_numeric_range(lower_limit = 0.1), two_class_example, + outcome = c(truth), estimate = c(Class1)) + Condition + Error in `fit()`: + ! Tailors with regression adjustments are not compatible with outcomes. + +--- + + Code + fit(tailor() %>% adjust_probability_threshold(0.1), two_class_example, outcome = c( + test_date), estimate = c(predicted), probabilities = c(Class1, Class2)) + Condition + Error in `fit()`: + ! Tailors with binary adjustments are not compatible with outcomes. + +--- + + Code + fit(tailor() %>% adjust_predictions_custom(hey = "there"), two_class_example, + outcome = c(test_date), estimate = c(predicted), probabilities = c(Class1)) + Condition + Error in `fit()`: + ! Only factor and numeric outcomes are currently supported. + diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index e950aee..722f262 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -73,3 +73,58 @@ test_that("tailor_requires_fit works", { ) ) }) + +test_that("fit.tailor() errors informatively with incompatible outcome", { + skip_if_not_installed("modeldata") + library(modeldata) + + two_class_example$test_numeric <- two_class_example$Class1 + 1 + two_class_example$test_date <- as.POSIXct(two_class_example$Class1) + + # supply a numeric outcome to a binary tailor + expect_snapshot( + error = TRUE, + fit( + tailor() %>% adjust_probability_threshold(.1), + two_class_example, + outcome = c(test_numeric), + estimate = c(predicted), + probabilities = c(Class1, Class2) + ) + ) + + # supply a factor outcome to a regression tailor + expect_snapshot( + error = TRUE, + fit( + tailor() %>% adjust_numeric_range(lower_limit = .1), + two_class_example, + outcome = c(truth), + estimate = c(Class1) + ) + ) + + # supply a totally wild outcome to a regression tailor + expect_snapshot( + error = TRUE, + fit( + tailor() %>% adjust_probability_threshold(.1), + two_class_example, + outcome = c(test_date), + estimate = c(predicted), + probabilities = c(Class1, Class2) + ) + ) + + # supply a totally wild outcome to an unknown tailor + expect_snapshot( + error = TRUE, + fit( + tailor() %>% adjust_predictions_custom(hey = "there"), + two_class_example, + outcome = c(test_date), + estimate = c(predicted), + probabilities = c(Class1) + ) + ) +})