diff --git a/DESCRIPTION b/DESCRIPTION index b76eb5d..7eafe80 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -27,12 +27,15 @@ Imports: purrr, rlang (>= 1.1.0), tibble, - tidyselect + tidyselect, + vctrs Suggests: + dials, modeldata, testthat (>= 3.0.0), workflows Remotes: + tidymodels/dials#358, tidymodels/probably, tidymodels/workflows Config/testthat/edition: 3 diff --git a/NAMESPACE b/NAMESPACE index d0e9232..94baa69 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,5 +1,7 @@ # Generated by roxygen2: do not edit by hand +S3method(extract_parameter_dials,tailor) +S3method(extract_parameter_set_dials,tailor) S3method(fit,equivocal_zone) S3method(fit,numeric_calibration) S3method(fit,numeric_range) @@ -33,6 +35,9 @@ S3method(tunable,numeric_range) S3method(tunable,predictions_custom) S3method(tunable,probability_calibration) S3method(tunable,probability_threshold) +S3method(tunable,tailor) +S3method(tune_args,adjustment) +S3method(tune_args,tailor) export("%>%") export(adjust_equivocal_zone) export(adjust_numeric_calibration) diff --git a/R/adjust-equivocal-zone.R b/R/adjust-equivocal-zone.R index cede76d..d983e1b 100644 --- a/R/adjust-equivocal-zone.R +++ b/R/adjust-equivocal-zone.R @@ -128,16 +128,15 @@ required_pkgs.equivocal_zone <- function(x, ...) { #' @export tunable.equivocal_zone <- function(x, ...) { - tibble::new_tibble(list( + tibble::tibble( name = "buffer", call_info = list(list(pkg = "dials", fun = "buffer")), source = "tailor", component = "equivocal_zone", component_id = "equivocal_zone" - )) + ) } # todo missing methods: -# todo tune_args # todo tidy # todo extract_parameter_set_dials diff --git a/R/adjust-numeric-calibration.R b/R/adjust-numeric-calibration.R index 8170961..ab2d3d3 100644 --- a/R/adjust-numeric-calibration.R +++ b/R/adjust-numeric-calibration.R @@ -126,6 +126,5 @@ tunable.numeric_calibration <- function(x, ...) { } # todo missing methods: -# todo tune_args # todo tidy # todo extract_parameter_set_dials diff --git a/R/adjust-numeric-range.R b/R/adjust-numeric-range.R index c1fa60a..f9c9daa 100644 --- a/R/adjust-numeric-range.R +++ b/R/adjust-numeric-range.R @@ -129,7 +129,7 @@ required_pkgs.numeric_range <- function(x, ...) { #' @export tunable.numeric_range <- function(x, ...) { - tibble::new_tibble(list( + tibble::tibble( name = c("lower_limit", "upper_limit"), call_info = list( list(pkg = "dials", fun = "lower_limit"), # todo make these dials functions @@ -138,10 +138,9 @@ tunable.numeric_range <- function(x, ...) { source = "tailor", component = "numeric_range", component_id = "numeric_range" - )) + ) } # todo missing methods: -# todo tune_args # todo tidy # todo extract_parameter_set_dials diff --git a/R/adjust-predictions-custom.R b/R/adjust-predictions-custom.R index 5ca5df0..5282484 100644 --- a/R/adjust-predictions-custom.R +++ b/R/adjust-predictions-custom.R @@ -92,6 +92,5 @@ tunable.predictions_custom <- function(x, ...) { } # todo missing methods: -# todo tune_args # todo tidy # todo extract_parameter_set_dials diff --git a/R/adjust-probability-calibration.R b/R/adjust-probability-calibration.R index 38fc04e..f36e407 100644 --- a/R/adjust-probability-calibration.R +++ b/R/adjust-probability-calibration.R @@ -134,6 +134,5 @@ tunable.probability_calibration <- function(x, ...) { } # todo missing methods: -# todo tune_args # todo tidy # todo extract_parameter_set_dials diff --git a/R/adjust-probability-threshold.R b/R/adjust-probability-threshold.R index 2ee095e..49743a9 100644 --- a/R/adjust-probability-threshold.R +++ b/R/adjust-probability-threshold.R @@ -113,16 +113,15 @@ required_pkgs.probability_threshold <- function(x, ...) { #' @export tunable.probability_threshold <- function(x, ...) { - tibble::new_tibble(list( + tibble::tibble( name = "threshold", call_info = list(list(pkg = "dials", fun = "threshold")), source = "tailor", component = "probability_threshold", component_id = "probability_threshold" - )) + ) } # todo missing methods: -# todo tune_args # todo tidy # todo extract_parameter_set_dials diff --git a/R/extract.R b/R/extract.R new file mode 100644 index 0000000..ecac918 --- /dev/null +++ b/R/extract.R @@ -0,0 +1,47 @@ +#' @export +extract_parameter_set_dials.tailor <- function(x, ...) { + all_args <- generics::tunable(x) + tuning_param <- generics::tune_args(x) + res <- + dplyr::inner_join( + tuning_param %>% dplyr::select(-tunable), + all_args, + by = c("name", "source", "component", "component_id") + ) %>% + dplyr::mutate(object = purrr::map(call_info, eval_call_info)) + + dials::parameters_constr( + res$name, + res$id, + res$source, + res$component, + res$component_id, + res$object + ) +} + +eval_call_info <- function(x) { + if (!is.null(x)) { + # Look for other options + allowed_opts <- c("range", "trans", "values") + if (any(names(x) %in% allowed_opts)) { + opts <- x[names(x) %in% allowed_opts] + } else { + opts <- list() + } + res <- try(rlang::eval_tidy(rlang::call2(x$fun, .ns = x$pkg, !!!opts)), silent = TRUE) + if (inherits(res, "try-error")) { + cli::cli_abort( + "Error when calling {.fn {x$fun}}: {as.character(res)}" + ) + } + } else { + res <- NA + } + res +} + +#' @export +extract_parameter_dials.tailor <- function(x, parameter, ...) { + extract_parameter_dials(extract_parameter_set_dials(x), parameter) +} diff --git a/R/tailor.R b/R/tailor.R index 5b67b54..c90f8ab 100644 --- a/R/tailor.R +++ b/R/tailor.R @@ -235,6 +235,42 @@ set_tailor_type <- function(object, y, call = caller_env()) { # todo: where to validate #levels? # todo setup eval_time # todo missing methods: -# todo tune_args + +#' @export +tune_args.tailor <- function(object, full = FALSE, ...) { + adjustments <- object$adjustments + + if (length(adjustments) == 0L) { + return(tune_tbl()) + } + + res <- purrr::map(object$adjustments, tune_args, full = full) + res <- purrr::list_rbind(res) + + tune_tbl( + res$name, + res$tunable, + res$id, + res$source, + res$component, + res$component_id, + full = full + ) +} + +#' @export +tunable.tailor <- function(x, ...) { + if (length(x$adjustments) == 0) { + res <- no_param + } else { + res <- purrr::map(x$adjustments, tunable) + res <- vctrs::vec_rbind(!!!res) + if (nrow(res) > 0) { + res <- res[!is.na(res$name), ] + } + } + res +} + # todo tidy (this should probably just be `adjustment_orderings()`) # todo extract_parameter_set_dials diff --git a/R/utils.R b/R/utils.R index 4f934c9..0a6b3f6 100644 --- a/R/utils.R +++ b/R/utils.R @@ -6,7 +6,7 @@ #' @name tailor-internals NULL - +# tuning machinery ------------------------------------------------------------- is_tune <- function(x) { if (!is.call(x)) { return(FALSE) @@ -14,8 +14,62 @@ is_tune <- function(x) { isTRUE(identical(quote(tune), x[[1]])) } -# for adjustments with no tunable parameters +tune_tbl <- function(name = character(), tunable = logical(), id = character(), + source = character(), component = character(), + component_id = character(), full = FALSE, call = caller_env()) { + complete_id <- id[!is.na(id)] + dups <- duplicated(complete_id) + + if (any(dups)) { + offenders <- unique(complete_id[dups]) + cli::cli_abort( + "{.val {offenders}} {?has a/have} duplicate {.field id} value{?s}.", + call = call + ) + } + + tune_tbl <- + tibble::tibble( + name = as.character(name), + tunable = as.logical(tunable), + id = as.character(id), + source = as.character(source), + component = as.character(component), + component_id = as.character(component_id) + ) + + if (!full) { + tune_tbl <- tune_tbl[tune_tbl$tunable, ] + } + + tune_tbl +} +#' @export +tune_args.adjustment <- function(object, full = FALSE, ...) { + # Grab the adjustment class before the subset, as that removes the class + adjustment_type <- class(object)[1] + + tune_param_list <- tunable(object)$name + + # remove the non-tunable arguments as they are not important + adjustment_args <- object$arguments[tune_param_list] + + res <- purrr::map_chr(adjustment_args, find_tune_id) + res <- ifelse(res == "", names(res), res) + + tune_tbl( + name = names(res), + tunable = unname(!is.na(res)), + id = unname(res), + source = "tailor", + component = adjustment_type, + component_id = adjustment_type, + full = full + ) +} + +# for adjustments with no tunable parameters no_param <- tibble::tibble( name = character(0), @@ -25,6 +79,87 @@ no_param <- component_id = character(0) ) +find_tune_id <- function(x, arg = caller_arg(x), call = caller_env()) { + if (length(x) == 0L) { + return(NA_character_) + } + if (rlang::is_quosures(x)) { + .x <- try(purrr::map(x, rlang::eval_tidy), silent = TRUE) + if (inherits(.x, "try-error")) { + x <- purrr::map(x, rlang::quo_get_expr) + } else { + x <- .x + } + } + id <- tune_id(x, call = call) + + if (!is.na(id)) { + return(id) + } + + if (is.atomic(x) | is.name(x) | length(x) == 1) { + return(NA_character_) + } + + tunable_elems <- vector("character", length = length(x)) + for (i in seq_along(x)) { + tunable_elems[i] <- find_tune_id(x[[i]], call = call) + } + tunable_elems <- tunable_elems[!is.na(tunable_elems)] + + if (length(tunable_elems) == 0) { + tunable_elems <- NA_character_ + } + + if (sum(tunable_elems == "", na.rm = TRUE) > 1) { + offenders <- paste0(deparse(x), collapse = "") + cli::cli_abort( + c( + "Only one tunable value is currently allowed per argument.", + "{.arg {arg}} has {.code {offenders}}." + ), + call = call + ) + } + + return(tunable_elems) +} + +tune_id <- function(x, call = caller_env()) { + if (is.null(x)) { + return(NA_character_) + } else { + if (rlang::is_quosures(x)) { + .x <- try(purrr::map(x, rlang::eval_tidy), silent = TRUE) + if (inherits(.x, "try-error")) { + x <- purrr::map(x, rlang::quo_get_expr) + } else { + x <- .x + } + if (is.null(x)) { + return(NA_character_) + } + } + + if (is.call(x)) { + if (rlang::call_name(x) == "tune") { + if (length(x) > 1) { + return(x[[2]]) + } else { + return("") + } + + return(x$id) + } else { + return(NA_character_) + } + } + } + + NA_character_ +} + +# new_adjustment ------------------------------------------------------------- # These values are used to specify "what will we need for the adjustment?" and # "what will we change?". For the outputs, we cannot change the probabilities # without changing the classes. This is important because we are going to have diff --git a/tests/testthat/_snaps/extract.md b/tests/testthat/_snaps/extract.md new file mode 100644 index 0000000..2d08c9b --- /dev/null +++ b/tests/testthat/_snaps/extract.md @@ -0,0 +1,16 @@ +# extract single parameter from tailor with no adjustments + + Code + extract_parameter_dials(tailor(), parameter = "none there") + Condition + Error in `extract_parameter_dials()`: + ! No parameter exists with id "none there". + +# extract single parameter from tailor with no tunable parameters + + Code + extract_parameter_dials(tlr, parameter = "none there") + Condition + Error in `extract_parameter_dials()`: + ! No parameter exists with id "none there". + diff --git a/tests/testthat/_snaps/utils.md b/tests/testthat/_snaps/utils.md index a655d80..963fdb4 100644 --- a/tests/testthat/_snaps/utils.md +++ b/tests/testthat/_snaps/utils.md @@ -42,3 +42,12 @@ Error in `fit()`: ! Only factor and numeric outcomes are currently supported. +# find_tune_id() works + + Code + find_tune_id(x) + Condition + Error: + ! Only one tunable value is currently allowed per argument. + `x` has `list(a = tune(), b = tune())`. + diff --git a/tests/testthat/helper-extract_parameter_set.R b/tests/testthat/helper-extract_parameter_set.R new file mode 100644 index 0000000..d88f1d6 --- /dev/null +++ b/tests/testthat/helper-extract_parameter_set.R @@ -0,0 +1,18 @@ +check_parameter_set_tibble <- function(x) { + expect_equal(names(x), c("name", "id", "source", "component", "component_id", "object")) + expect_equal(class(x$name), "character") + expect_equal(class(x$id), "character") + expect_equal(class(x$source), "character") + expect_equal(class(x$component), "character") + expect_equal(class(x$component_id), "character") + expect_true(!any(duplicated(x$id))) + + expect_equal(class(x$object), "list") + obj_check <- purrr::map_lgl( + x$object, + function(.x) inherits(.x, "param") | all(is.na(.x)) + ) + expect_true(all(obj_check)) + + invisible(TRUE) +} diff --git a/tests/testthat/test-adjust-equivocal-zone.R b/tests/testthat/test-adjust-equivocal-zone.R index f4959d7..c2a5795 100644 --- a/tests/testthat/test-adjust-equivocal-zone.R +++ b/tests/testthat/test-adjust-equivocal-zone.R @@ -55,3 +55,18 @@ test_that("adjustment printing", { expect_snapshot(tailor() %>% adjust_equivocal_zone()) expect_snapshot(tailor() %>% adjust_equivocal_zone(hardhat::tune())) }) + +test_that("tunable", { + tlr <- + tailor() %>% + adjust_equivocal_zone(value = 1 / 4) + adj_param <- tunable(tlr$adjustments[[1]]) + expect_equal(adj_param$name, c("buffer")) + expect_true(all(adj_param$source == "tailor")) + expect_true(is.list(adj_param$call_info)) + expect_equal(nrow(adj_param), 1) + expect_equal( + names(adj_param), + c("name", "call_info", "source", "component", "component_id") + ) +}) diff --git a/tests/testthat/test-adjust-numeric-calibration.R b/tests/testthat/test-adjust-numeric-calibration.R index 2e4c37d..7532675 100644 --- a/tests/testthat/test-adjust-numeric-calibration.R +++ b/tests/testthat/test-adjust-numeric-calibration.R @@ -82,3 +82,11 @@ test_that("errors informatively with bad input", { expect_no_condition(adjust_numeric_calibration(tailor())) expect_no_condition(adjust_numeric_calibration(tailor(), "linear")) }) + +test_that("tunable", { + tlr <- + tailor() %>% + adjust_numeric_calibration(method = "linear") + adj_param <- tunable(tlr$adjustments[[1]]) + expect_equal(adj_param, no_param) +}) diff --git a/tests/testthat/test-adjust-numeric-range.R b/tests/testthat/test-adjust-numeric-range.R index b5adaee..5318690 100644 --- a/tests/testthat/test-adjust-numeric-range.R +++ b/tests/testthat/test-adjust-numeric-range.R @@ -41,3 +41,17 @@ test_that("adjustment printing", { expect_snapshot(tailor() %>% adjust_numeric_range(hardhat::tune(), 1)) }) +test_that("tunable", { + tlr <- + tailor() %>% + adjust_numeric_range(lower_limit = 1, upper_limit = 2) + adj_param <- tunable(tlr$adjustments[[1]]) + expect_equal(adj_param$name, c("lower_limit", "upper_limit")) + expect_true(all(adj_param$source == "tailor")) + expect_true(is.list(adj_param$call_info)) + expect_equal(nrow(adj_param), 2) + expect_equal( + names(adj_param), + c("name", "call_info", "source", "component", "component_id") + ) +}) diff --git a/tests/testthat/test-adjust-predictions-custom.R b/tests/testthat/test-adjust-predictions-custom.R index f61e745..8643f32 100644 --- a/tests/testthat/test-adjust-predictions-custom.R +++ b/tests/testthat/test-adjust-predictions-custom.R @@ -43,3 +43,11 @@ test_that("basic adjust_predictions_custom() usage works", { test_that("adjustment printing", { expect_snapshot(tailor() %>% adjust_predictions_custom()) }) + +test_that("tunable", { + tlr <- + tailor() %>% + adjust_predictions_custom(linear_predictor = binomial()$linkfun(Class2)) + adj_param <- tunable(tlr$adjustments[[1]]) + expect_equal(adj_param, no_param) +}) diff --git a/tests/testthat/test-adjust-probability-calibration.R b/tests/testthat/test-adjust-probability-calibration.R index 0ca01b7..ba4c33f 100644 --- a/tests/testthat/test-adjust-probability-calibration.R +++ b/tests/testthat/test-adjust-probability-calibration.R @@ -101,3 +101,11 @@ test_that("errors informatively with bad input", { expect_no_condition(adjust_numeric_calibration(tailor())) expect_no_condition(adjust_numeric_calibration(tailor(), "linear")) }) + +test_that("tunable", { + tlr <- + tailor() %>% + adjust_probability_calibration(method = "logistic") + adj_param <- tunable(tlr$adjustments[[1]]) + expect_equal(adj_param, no_param) +}) diff --git a/tests/testthat/test-adjust-probability-threshold.R b/tests/testthat/test-adjust-probability-threshold.R index cf533b0..a7c5972 100644 --- a/tests/testthat/test-adjust-probability-threshold.R +++ b/tests/testthat/test-adjust-probability-threshold.R @@ -40,3 +40,18 @@ test_that("adjustment printing", { expect_snapshot(tailor() %>% adjust_probability_threshold()) expect_snapshot(tailor() %>% adjust_probability_threshold(hardhat::tune())) }) + +test_that("tunable", { + tlr <- + tailor() %>% + adjust_probability_threshold(.1) + adj_param <- tunable(tlr$adjustments[[1]]) + expect_equal(adj_param$name, "threshold") + expect_true(all(adj_param$source == "tailor")) + expect_true(is.list(adj_param$call_info)) + expect_equal(nrow(adj_param), 1) + expect_equal( + names(adj_param), + c("name", "call_info", "source", "component", "component_id") + ) +}) diff --git a/tests/testthat/test-extract.R b/tests/testthat/test-extract.R new file mode 100644 index 0000000..0c66abe --- /dev/null +++ b/tests/testthat/test-extract.R @@ -0,0 +1,109 @@ +test_that("extract parameter set from tailor with no adjustments", { + skip_if_not_installed("dials") + + bare_tlr <- tailor() + + bare_info <- extract_parameter_set_dials(bare_tlr) + check_parameter_set_tibble(bare_info) + expect_equal(nrow(bare_info), 0) +}) + +test_that("extract parameter set from tailor with no tunable parameters", { + skip_if_not_installed("dials") + + tlr <- + tailor() %>% + adjust_predictions_custom() + + tlr_info <- extract_parameter_set_dials(tlr) + + check_parameter_set_tibble(tlr_info) + expect_equal(nrow(tlr_info), 0) +}) + +test_that("extract parameter set from tailor with a tunable parameter", { + skip_if_not_installed("dials") + + tlr <- + tailor() %>% + adjust_numeric_calibration() %>% + adjust_numeric_range(lower_limit = hardhat::tune()) + + tlr_info <- extract_parameter_set_dials(tlr) + + check_parameter_set_tibble(tlr_info) + expect_equal(nrow(tlr_info), 1) + + expect_equal(tlr_info$component, "numeric_range") + expect_true(all(tlr_info$source == "tailor")) + expect_equal(tlr_info$name, "lower_limit") + expect_equal(tlr_info$id, "lower_limit") + + expect_equal(tlr_info$object[[1]], dials::lower_limit(c(-Inf, Inf))) +}) + +test_that("extract parameter set from tailor with multiple tunable parameters", { + skip_if_not_installed("dials") + + tlr <- + tailor() %>% + adjust_numeric_calibration() %>% + adjust_numeric_range( + lower_limit = hardhat::tune(), + upper_limit = hardhat::tune() + ) + + tlr_info <- extract_parameter_set_dials(tlr) + + check_parameter_set_tibble(tlr_info) + expect_equal(nrow(tlr_info), 2) + + expect_equal(tlr_info$component, rep("numeric_range", 2)) + expect_true(all(tlr_info$source == "tailor")) + expect_equal(tlr_info$name, c("lower_limit", "upper_limit")) + expect_equal(tlr_info$id, c("lower_limit", "upper_limit")) + + expect_equal(tlr_info$object[[1]], dials::lower_limit(c(-Inf, Inf))) + expect_equal(tlr_info$object[[2]], dials::upper_limit(c(-Inf, Inf))) +}) + +# ------------------------------------------------------------------------- + +test_that("extract single parameter from tailor with no adjustments", { + skip_if_not_installed("dials") + + expect_snapshot( + error = TRUE, + extract_parameter_dials(tailor(), parameter = "none there") + ) +}) + +test_that("extract single parameter from tailor with no tunable parameters", { + skip_if_not_installed("dials") + + tlr <- + tailor() %>% + adjust_numeric_calibration() + + expect_snapshot( + error = TRUE, + extract_parameter_dials(tlr, parameter = "none there") + ) +}) + +test_that("extract single parameter from tailor with tunable parameters", { + skip_if_not_installed("dials") + + tlr <- + tailor() %>% + adjust_numeric_calibration() %>% + adjust_numeric_range( + lower_limit = hardhat::tune(), + upper_limit = hardhat::tune() + ) + + expect_equal( + extract_parameter_dials(tlr, "lower_limit"), + dials::lower_limit() + ) +}) diff --git a/tests/testthat/test-tailor.R b/tests/testthat/test-tailor.R index e951e87..8a1e2cd 100644 --- a/tests/testthat/test-tailor.R +++ b/tests/testthat/test-tailor.R @@ -107,3 +107,33 @@ test_that("error informatively with empty tidyselections", { ) ) }) + +test_that("tunable (no adjustments)", { + tlr <- + tailor() + + tlr_param <- tunable(tlr) + expect_equal(tlr_param, no_param) +}) + +test_that("tunable (multiple adjustments)", { + tlr <- + tailor() %>% + adjust_probability_threshold(.2) %>% + adjust_equivocal_zone() + + tlr_param <- tunable(tlr) + expect_equal(tlr_param$name, c("threshold", "buffer")) + expect_true(all(tlr_param$source == "tailor")) + expect_true(is.list(tlr_param$call_info)) + expect_equal(nrow(tlr_param), 2) + expect_equal( + names(tlr_param), + c("name", "call_info", "source", "component", "component_id") + ) + + expect_equal( + tlr_param, + dplyr::bind_rows(tunable(tlr$adjustments[[1]]), tunable(tlr$adjustments[[2]])) + ) +}) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 81df024..7b2d3b0 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -125,3 +125,34 @@ test_that("fit.tailor() errors informatively with incompatible outcome", { ) ) }) + +test_that("find_tune_id() works", { + # empty input + expect_equal(find_tune_id(list()), NA_character_) + + # handles quosures + x <- rlang::quos(a = 1, b = tune()) + expect_equal(find_tune_id(x), "") + + # non-tunable atomic values + expect_equal(find_tune_id(1), NA_character_) + expect_equal(find_tune_id("a"), NA_character_) + expect_equal(find_tune_id(TRUE), NA_character_) + + # non-tunable names + expect_equal(find_tune_id(quote(x)), NA_character_) + + # nested lists + x <- list(a = 1, b = list(c = hardhat::tune(), d = 2)) + expect_equal(find_tune_id(x), "") + + # tune() without id + expect_equal(find_tune_id(hardhat::tune()), "") + + # tune() with id + expect_equal(find_tune_id(hardhat::tune("test_id")), "test_id") + + # multiple tunable values + x <- list(a = hardhat::tune(), b = hardhat::tune()) + expect_snapshot(error = TRUE, find_tune_id(x)) +})