Skip to content

Commit

Permalink
implement tune_args() and tunable() (#51)
Browse files Browse the repository at this point in the history
* implement `tune_args()` and `tunable()`

* remove redundant method

* namespace fn in test

* more machinery from recipes

* add vctrs to Imports

* test `find_tune_id()`

* add snapshot

* address `vec_rbind()` internal error re: bad recycling

* update for new object structure

* add `extract_parameter_set_dials()` method

* migrate check helper from workflows

* generate snaps
  • Loading branch information
simonpcouch authored Oct 23, 2024
1 parent 317a4db commit 776c2f4
Show file tree
Hide file tree
Showing 23 changed files with 517 additions and 16 deletions.
5 changes: 4 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,15 @@ Imports:
purrr,
rlang (>= 1.1.0),
tibble,
tidyselect
tidyselect,
vctrs
Suggests:
dials,
modeldata,
testthat (>= 3.0.0),
workflows
Remotes:
tidymodels/dials#358,
tidymodels/probably,
tidymodels/workflows
Config/testthat/edition: 3
Expand Down
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Generated by roxygen2: do not edit by hand

S3method(extract_parameter_dials,tailor)
S3method(extract_parameter_set_dials,tailor)
S3method(fit,equivocal_zone)
S3method(fit,numeric_calibration)
S3method(fit,numeric_range)
Expand Down Expand Up @@ -33,6 +35,9 @@ S3method(tunable,numeric_range)
S3method(tunable,predictions_custom)
S3method(tunable,probability_calibration)
S3method(tunable,probability_threshold)
S3method(tunable,tailor)
S3method(tune_args,adjustment)
S3method(tune_args,tailor)
export("%>%")
export(adjust_equivocal_zone)
export(adjust_numeric_calibration)
Expand Down
5 changes: 2 additions & 3 deletions R/adjust-equivocal-zone.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,15 @@ required_pkgs.equivocal_zone <- function(x, ...) {

#' @export
tunable.equivocal_zone <- function(x, ...) {
tibble::new_tibble(list(
tibble::tibble(
name = "buffer",
call_info = list(list(pkg = "dials", fun = "buffer")),
source = "tailor",
component = "equivocal_zone",
component_id = "equivocal_zone"
))
)
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
1 change: 0 additions & 1 deletion R/adjust-numeric-calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,5 @@ tunable.numeric_calibration <- function(x, ...) {
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
5 changes: 2 additions & 3 deletions R/adjust-numeric-range.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ required_pkgs.numeric_range <- function(x, ...) {

#' @export
tunable.numeric_range <- function(x, ...) {
tibble::new_tibble(list(
tibble::tibble(
name = c("lower_limit", "upper_limit"),
call_info = list(
list(pkg = "dials", fun = "lower_limit"), # todo make these dials functions
Expand All @@ -138,10 +138,9 @@ tunable.numeric_range <- function(x, ...) {
source = "tailor",
component = "numeric_range",
component_id = "numeric_range"
))
)
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
1 change: 0 additions & 1 deletion R/adjust-predictions-custom.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,5 @@ tunable.predictions_custom <- function(x, ...) {
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
1 change: 0 additions & 1 deletion R/adjust-probability-calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,5 @@ tunable.probability_calibration <- function(x, ...) {
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
5 changes: 2 additions & 3 deletions R/adjust-probability-threshold.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,15 @@ required_pkgs.probability_threshold <- function(x, ...) {

#' @export
tunable.probability_threshold <- function(x, ...) {
tibble::new_tibble(list(
tibble::tibble(
name = "threshold",
call_info = list(list(pkg = "dials", fun = "threshold")),
source = "tailor",
component = "probability_threshold",
component_id = "probability_threshold"
))
)
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
47 changes: 47 additions & 0 deletions R/extract.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#' @export
extract_parameter_set_dials.tailor <- function(x, ...) {
all_args <- generics::tunable(x)
tuning_param <- generics::tune_args(x)
res <-
dplyr::inner_join(
tuning_param %>% dplyr::select(-tunable),
all_args,
by = c("name", "source", "component", "component_id")
) %>%
dplyr::mutate(object = purrr::map(call_info, eval_call_info))

dials::parameters_constr(
res$name,
res$id,
res$source,
res$component,
res$component_id,
res$object
)
}

eval_call_info <- function(x) {
if (!is.null(x)) {
# Look for other options
allowed_opts <- c("range", "trans", "values")
if (any(names(x) %in% allowed_opts)) {
opts <- x[names(x) %in% allowed_opts]
} else {
opts <- list()
}
res <- try(rlang::eval_tidy(rlang::call2(x$fun, .ns = x$pkg, !!!opts)), silent = TRUE)
if (inherits(res, "try-error")) {
cli::cli_abort(
"Error when calling {.fn {x$fun}}: {as.character(res)}"
)
}
} else {
res <- NA
}
res
}

#' @export
extract_parameter_dials.tailor <- function(x, parameter, ...) {
extract_parameter_dials(extract_parameter_set_dials(x), parameter)
}
38 changes: 37 additions & 1 deletion R/tailor.R
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,42 @@ set_tailor_type <- function(object, y, call = caller_env()) {
# todo: where to validate #levels?
# todo setup eval_time
# todo missing methods:
# todo tune_args

#' @export
tune_args.tailor <- function(object, full = FALSE, ...) {
adjustments <- object$adjustments

if (length(adjustments) == 0L) {
return(tune_tbl())
}

res <- purrr::map(object$adjustments, tune_args, full = full)
res <- purrr::list_rbind(res)

tune_tbl(
res$name,
res$tunable,
res$id,
res$source,
res$component,
res$component_id,
full = full
)
}

#' @export
tunable.tailor <- function(x, ...) {
if (length(x$adjustments) == 0) {
res <- no_param
} else {
res <- purrr::map(x$adjustments, tunable)
res <- vctrs::vec_rbind(!!!res)
if (nrow(res) > 0) {
res <- res[!is.na(res$name), ]
}
}
res
}

# todo tidy (this should probably just be `adjustment_orderings()`)
# todo extract_parameter_set_dials
139 changes: 137 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,70 @@
#' @name tailor-internals
NULL


# tuning machinery -------------------------------------------------------------
is_tune <- function(x) {
if (!is.call(x)) {
return(FALSE)
}
isTRUE(identical(quote(tune), x[[1]]))
}

# for adjustments with no tunable parameters
tune_tbl <- function(name = character(), tunable = logical(), id = character(),
source = character(), component = character(),
component_id = character(), full = FALSE, call = caller_env()) {
complete_id <- id[!is.na(id)]
dups <- duplicated(complete_id)

if (any(dups)) {
offenders <- unique(complete_id[dups])
cli::cli_abort(
"{.val {offenders}} {?has a/have} duplicate {.field id} value{?s}.",
call = call
)
}

tune_tbl <-
tibble::tibble(
name = as.character(name),
tunable = as.logical(tunable),
id = as.character(id),
source = as.character(source),
component = as.character(component),
component_id = as.character(component_id)
)

if (!full) {
tune_tbl <- tune_tbl[tune_tbl$tunable, ]
}

tune_tbl
}

#' @export
tune_args.adjustment <- function(object, full = FALSE, ...) {
# Grab the adjustment class before the subset, as that removes the class
adjustment_type <- class(object)[1]

tune_param_list <- tunable(object)$name

# remove the non-tunable arguments as they are not important
adjustment_args <- object$arguments[tune_param_list]

res <- purrr::map_chr(adjustment_args, find_tune_id)
res <- ifelse(res == "", names(res), res)

tune_tbl(
name = names(res),
tunable = unname(!is.na(res)),
id = unname(res),
source = "tailor",
component = adjustment_type,
component_id = adjustment_type,
full = full
)
}

# for adjustments with no tunable parameters
no_param <-
tibble::tibble(
name = character(0),
Expand All @@ -25,6 +79,87 @@ no_param <-
component_id = character(0)
)

find_tune_id <- function(x, arg = caller_arg(x), call = caller_env()) {
if (length(x) == 0L) {
return(NA_character_)
}
if (rlang::is_quosures(x)) {
.x <- try(purrr::map(x, rlang::eval_tidy), silent = TRUE)
if (inherits(.x, "try-error")) {
x <- purrr::map(x, rlang::quo_get_expr)
} else {
x <- .x
}
}
id <- tune_id(x, call = call)

if (!is.na(id)) {
return(id)
}

if (is.atomic(x) | is.name(x) | length(x) == 1) {
return(NA_character_)
}

tunable_elems <- vector("character", length = length(x))
for (i in seq_along(x)) {
tunable_elems[i] <- find_tune_id(x[[i]], call = call)
}
tunable_elems <- tunable_elems[!is.na(tunable_elems)]

if (length(tunable_elems) == 0) {
tunable_elems <- NA_character_
}

if (sum(tunable_elems == "", na.rm = TRUE) > 1) {
offenders <- paste0(deparse(x), collapse = "")
cli::cli_abort(
c(
"Only one tunable value is currently allowed per argument.",
"{.arg {arg}} has {.code {offenders}}."
),
call = call
)
}

return(tunable_elems)
}

tune_id <- function(x, call = caller_env()) {
if (is.null(x)) {
return(NA_character_)
} else {
if (rlang::is_quosures(x)) {
.x <- try(purrr::map(x, rlang::eval_tidy), silent = TRUE)
if (inherits(.x, "try-error")) {
x <- purrr::map(x, rlang::quo_get_expr)
} else {
x <- .x
}
if (is.null(x)) {
return(NA_character_)
}
}

if (is.call(x)) {
if (rlang::call_name(x) == "tune") {
if (length(x) > 1) {
return(x[[2]])
} else {
return("")
}

return(x$id)
} else {
return(NA_character_)
}
}
}

NA_character_
}

# new_adjustment -------------------------------------------------------------
# These values are used to specify "what will we need for the adjustment?" and
# "what will we change?". For the outputs, we cannot change the probabilities
# without changing the classes. This is important because we are going to have
Expand Down
16 changes: 16 additions & 0 deletions tests/testthat/_snaps/extract.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# extract single parameter from tailor with no adjustments

Code
extract_parameter_dials(tailor(), parameter = "none there")
Condition
Error in `extract_parameter_dials()`:
! No parameter exists with id "none there".

# extract single parameter from tailor with no tunable parameters

Code
extract_parameter_dials(tlr, parameter = "none there")
Condition
Error in `extract_parameter_dials()`:
! No parameter exists with id "none there".

9 changes: 9 additions & 0 deletions tests/testthat/_snaps/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,12 @@
Error in `fit()`:
! Only factor and numeric outcomes are currently supported.

# find_tune_id() works

Code
find_tune_id(x)
Condition
Error:
! Only one tunable value is currently allowed per argument.
`x` has `list(a = tune(), b = tune())`.

Loading

0 comments on commit 776c2f4

Please sign in to comment.