Skip to content

Commit

Permalink
add tailor_fully_trained
Browse files Browse the repository at this point in the history
can be imported from workflows and tune
  • Loading branch information
simonpcouch committed May 23, 2024
1 parent d1283f3 commit fe07602
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 0 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export(extract_parameter_set_dials)
export(fit)
export(required_pkgs)
export(tailor)
export(tailor_fully_trained)
export(tidy)
export(tunable)
export(tune_args)
Expand Down
24 changes: 24 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
#' Internal tailor functions
#'
#' Utilities for use in downstream packages.
#'
#' @keywords internal
#' @name tailor-internals
NULL


is_tune <- function(x) {
if (!is.call(x)) {
return(FALSE)
Expand Down Expand Up @@ -48,6 +57,21 @@ is_tailor <- function(x) {
inherits(x, "tailor")
}

#' @export
#' @keywords internal
#' @rdname tailor-internals
tailor_fully_trained <- function(x) {
if (length(x$operations) == 0L) {
return(FALSE)
}

all(purrr::map_lgl(x$operations, tailor_operation_trained))
}

tailor_operation_trained <- function(x) {
isTRUE(x$trained)
}

# ad-hoc checking --------------------------------------------------------------
check_tailor <- function(x, calibration_type = NULL, call = caller_env(), arg = caller_arg(x)) {
if (!is_tailor(x)) {
Expand Down
13 changes: 13 additions & 0 deletions man/tailor-internals.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 48 additions & 0 deletions tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,51 @@ test_that("check_tailor raises informative error", {
expect_snapshot(error = TRUE, adjust_probability_threshold("boop"))
expect_no_condition(tailor() %>% adjust_probability_threshold(.5))
})

test_that("tailor_fully_trained works", {
skip_if_not_installed("modeldata")
data("two_class_example", package = "modeldata")
expect_false(tailor_fully_trained(tailor()))
expect_false(
tailor_fully_trained(tailor() %>% adjust_probability_threshold(.5))
)
expect_false(
tailor_fully_trained(
tailor() %>%
adjust_probability_calibration("logistic") %>%
fit(
two_class_example,
outcome = "truth",
estimate = tidyselect::contains("Class")
) %>%
adjust_probability_threshold(.5)
)
)

expect_true(
tailor_fully_trained(
tailor() %>%
adjust_probability_calibration("logistic") %>%
fit(
two_class_example,
outcome = "truth",
# todo: this function requires a different format of `estimate`
# and `probabilities` specification than the call below to
# be able to fit properly.
estimate = tidyselect::contains("Class")
)
)
)
expect_true(
tailor_fully_trained(
tailor() %>%
adjust_probability_threshold(.5) %>%
fit(
two_class_example,
outcome = "truth",
estimate = "predicted",
probabilities = tidyselect::contains("Class")
)
)
)
})

0 comments on commit fe07602

Please sign in to comment.