Skip to content

Commit

Permalink
make set_tailor_type() play nice with infer_type() (closes #38)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Sep 17, 2024
1 parent d80ef2a commit 9383acd
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 2 deletions.
8 changes: 6 additions & 2 deletions R/tailor.R
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,9 @@ predict.tailor <- function(object, new_data, ...) {
new_data
}

set_tailor_type <- function(object, y) {
set_tailor_type <- function(object, y, call = caller_env()) {
if (object$type != "unknown") {
check_outcome_type(y, object$type, call = call)
return(object)
}
if (is.factor(y)) {
Expand All @@ -223,7 +224,10 @@ set_tailor_type <- function(object, y) {
} else if (is.numeric(y)) {
object$type <- "regression"
} else {
cli_abort("Only factor and numeric outcomes are currently supported.")
cli_abort(
"Only factor and numeric outcomes are currently supported.",
call = call
)
}
object
}
Expand Down
21 changes: 21 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,27 @@ check_method <- function(method,
method
}

# at `fit()` time, we check the type of the outcome vs the type
# supported by the applied adjustments. where this is called currently,
# we know already that `type` is not "unknown"
check_outcome_type <- function(outcome, type, call) {
outcome_is_compatible <-
switch(
type,
regression = is.numeric(outcome),
binary = , multiclass = is.factor(outcome),
FALSE
)

if (!outcome_is_compatible) {
cli_abort(
"Tailors with {type} adjustments are not compatible
with {.cls {class(outcome)}} outcomes.",
call = call
)
}
}

check_selection <- function(selector, result, arg, call = caller_env()) {
if (length(result) == 0) {
cli_abort(
Expand Down
36 changes: 36 additions & 0 deletions tests/testthat/_snaps/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,39 @@
Error in `adjust_probability_threshold()`:
! `x` should be a <tailor> (`?tailor::tailor()`), not a string.

# fit.tailor() errors informatively with incompatible outcome

Code
fit(tailor() %>% adjust_probability_threshold(0.1), two_class_example, outcome = c(
test_numeric), estimate = c(predicted), probabilities = c(Class1, Class2))
Condition
Error in `fit()`:
! Tailors with binary adjustments are not compatible with <numeric> outcomes.

---

Code
fit(tailor() %>% adjust_numeric_range(lower_limit = 0.1), two_class_example,
outcome = c(truth), estimate = c(Class1))
Condition
Error in `fit()`:
! Tailors with regression adjustments are not compatible with <factor> outcomes.

---

Code
fit(tailor() %>% adjust_probability_threshold(0.1), two_class_example, outcome = c(
test_date), estimate = c(predicted), probabilities = c(Class1, Class2))
Condition
Error in `fit()`:
! Tailors with binary adjustments are not compatible with <POSIXct/POSIXt> outcomes.

---

Code
fit(tailor() %>% adjust_predictions_custom(hey = "there"), two_class_example,
outcome = c(test_date), estimate = c(predicted), probabilities = c(Class1))
Condition
Error in `fit()`:
! Only factor and numeric outcomes are currently supported.

55 changes: 55 additions & 0 deletions tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,58 @@ test_that("tailor_requires_fit works", {
)
)
})

test_that("fit.tailor() errors informatively with incompatible outcome", {
skip_if_not_installed("modeldata")
library(modeldata)

two_class_example$test_numeric <- two_class_example$Class1 + 1
two_class_example$test_date <- as.POSIXct(two_class_example$Class1)

# supply a numeric outcome to a binary tailor
expect_snapshot(
error = TRUE,
fit(
tailor() %>% adjust_probability_threshold(.1),
two_class_example,
outcome = c(test_numeric),
estimate = c(predicted),
probabilities = c(Class1, Class2)
)
)

# supply a factor outcome to a regression tailor
expect_snapshot(
error = TRUE,
fit(
tailor() %>% adjust_numeric_range(lower_limit = .1),
two_class_example,
outcome = c(truth),
estimate = c(Class1)
)
)

# supply a totally wild outcome to a regression tailor
expect_snapshot(
error = TRUE,
fit(
tailor() %>% adjust_probability_threshold(.1),
two_class_example,
outcome = c(test_date),
estimate = c(predicted),
probabilities = c(Class1, Class2)
)
)

# supply a totally wild outcome to an unknown tailor
expect_snapshot(
error = TRUE,
fit(
tailor() %>% adjust_predictions_custom(hey = "there"),
two_class_example,
outcome = c(test_date),
estimate = c(predicted),
probabilities = c(Class1)
)
)
})

0 comments on commit 9383acd

Please sign in to comment.