Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement tune_args() and tunable() #51

Merged
merged 12 commits into from
Oct 23, 2024
5 changes: 4 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,15 @@ Imports:
purrr,
rlang (>= 1.1.0),
tibble,
tidyselect
tidyselect,
vctrs
Suggests:
dials,
modeldata,
testthat (>= 3.0.0),
workflows
Remotes:
tidymodels/dials#358,
tidymodels/probably,
tidymodels/workflows
Config/testthat/edition: 3
Expand Down
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Generated by roxygen2: do not edit by hand

S3method(extract_parameter_dials,tailor)
S3method(extract_parameter_set_dials,tailor)
S3method(fit,equivocal_zone)
S3method(fit,numeric_calibration)
S3method(fit,numeric_range)
Expand Down Expand Up @@ -33,6 +35,9 @@ S3method(tunable,numeric_range)
S3method(tunable,predictions_custom)
S3method(tunable,probability_calibration)
S3method(tunable,probability_threshold)
S3method(tunable,tailor)
S3method(tune_args,adjustment)
S3method(tune_args,tailor)
export("%>%")
export(adjust_equivocal_zone)
export(adjust_numeric_calibration)
Expand Down
5 changes: 2 additions & 3 deletions R/adjust-equivocal-zone.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,15 @@ required_pkgs.equivocal_zone <- function(x, ...) {

#' @export
tunable.equivocal_zone <- function(x, ...) {
tibble::new_tibble(list(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no arguments here, but I'm curious why you made that change. Is it for readability (since these aren't likely to be called a large number of times)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for calling that out. I'd like to revisit that and return to using new_tibble(), though there was a vector recycling issue that resulted in an interval vctrs error later on; I think we'll just need to do some manual rep()ping.

tibble::tibble(
name = "buffer",
call_info = list(list(pkg = "dials", fun = "buffer")),
source = "tailor",
component = "equivocal_zone",
component_id = "equivocal_zone"
))
)
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
1 change: 0 additions & 1 deletion R/adjust-numeric-calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,5 @@ tunable.numeric_calibration <- function(x, ...) {
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
5 changes: 2 additions & 3 deletions R/adjust-numeric-range.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ required_pkgs.numeric_range <- function(x, ...) {

#' @export
tunable.numeric_range <- function(x, ...) {
tibble::new_tibble(list(
tibble::tibble(
name = c("lower_limit", "upper_limit"),
call_info = list(
list(pkg = "dials", fun = "lower_limit"), # todo make these dials functions
Expand All @@ -138,10 +138,9 @@ tunable.numeric_range <- function(x, ...) {
source = "tailor",
component = "numeric_range",
component_id = "numeric_range"
))
)
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
1 change: 0 additions & 1 deletion R/adjust-predictions-custom.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,5 @@ tunable.predictions_custom <- function(x, ...) {
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
1 change: 0 additions & 1 deletion R/adjust-probability-calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,5 @@ tunable.probability_calibration <- function(x, ...) {
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
5 changes: 2 additions & 3 deletions R/adjust-probability-threshold.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,15 @@ required_pkgs.probability_threshold <- function(x, ...) {

#' @export
tunable.probability_threshold <- function(x, ...) {
tibble::new_tibble(list(
tibble::tibble(
name = "threshold",
call_info = list(list(pkg = "dials", fun = "threshold")),
source = "tailor",
component = "probability_threshold",
component_id = "probability_threshold"
))
)
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
47 changes: 47 additions & 0 deletions R/extract.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#' @export
extract_parameter_set_dials.tailor <- function(x, ...) {
all_args <- generics::tunable(x)
tuning_param <- generics::tune_args(x)
res <-
dplyr::inner_join(
tuning_param %>% dplyr::select(-tunable),
all_args,
by = c("name", "source", "component", "component_id")
) %>%
dplyr::mutate(object = purrr::map(call_info, eval_call_info))

dials::parameters_constr(
res$name,
res$id,
res$source,
res$component,
res$component_id,
res$object
)
}

eval_call_info <- function(x) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make a small standalone file? If not, add some notes/link to original and noocv?

(same for other replicas below)

if (!is.null(x)) {
# Look for other options
allowed_opts <- c("range", "trans", "values")
if (any(names(x) %in% allowed_opts)) {
opts <- x[names(x) %in% allowed_opts]
} else {
opts <- list()
}
res <- try(rlang::eval_tidy(rlang::call2(x$fun, .ns = x$pkg, !!!opts)), silent = TRUE)
if (inherits(res, "try-error")) {
cli::cli_abort(
"Error when calling {.fn {x$fun}}: {as.character(res)}"
)
}
} else {
res <- NA
}
res
}

#' @export
extract_parameter_dials.tailor <- function(x, parameter, ...) {
extract_parameter_dials(extract_parameter_set_dials(x), parameter)
}
38 changes: 37 additions & 1 deletion R/tailor.R
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,42 @@ set_tailor_type <- function(object, y, call = caller_env()) {
# todo: where to validate #levels?
# todo setup eval_time
# todo missing methods:
# todo tune_args

#' @export
tune_args.tailor <- function(object, full = FALSE, ...) {
adjustments <- object$adjustments

if (length(adjustments) == 0L) {
return(tune_tbl())
}

res <- purrr::map(object$adjustments, tune_args, full = full)
res <- purrr::list_rbind(res)

tune_tbl(
res$name,
res$tunable,
res$id,
res$source,
res$component,
res$component_id,
full = full
)
}

#' @export
tunable.tailor <- function(x, ...) {
if (length(x$adjustments) == 0) {
res <- no_param
} else {
res <- purrr::map(x$adjustments, tunable)
res <- vctrs::vec_rbind(!!!res)
if (nrow(res) > 0) {
res <- res[!is.na(res$name), ]
}
}
res
}

# todo tidy (this should probably just be `adjustment_orderings()`)
# todo extract_parameter_set_dials
139 changes: 137 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,70 @@
#' @name tailor-internals
NULL


# tuning machinery -------------------------------------------------------------
is_tune <- function(x) {
if (!is.call(x)) {
return(FALSE)
}
isTRUE(identical(quote(tune), x[[1]]))
}

# for adjustments with no tunable parameters
tune_tbl <- function(name = character(), tunable = logical(), id = character(),
source = character(), component = character(),
component_id = character(), full = FALSE, call = caller_env()) {
complete_id <- id[!is.na(id)]
dups <- duplicated(complete_id)

if (any(dups)) {
offenders <- unique(complete_id[dups])
cli::cli_abort(
"{.val {offenders}} {?has a/have} duplicate {.field id} value{?s}.",
call = call
)
}

tune_tbl <-
tibble::tibble(
name = as.character(name),
tunable = as.logical(tunable),
id = as.character(id),
source = as.character(source),
component = as.character(component),
component_id = as.character(component_id)
)

if (!full) {
tune_tbl <- tune_tbl[tune_tbl$tunable, ]
}

tune_tbl
}

#' @export
tune_args.adjustment <- function(object, full = FALSE, ...) {
# Grab the adjustment class before the subset, as that removes the class
adjustment_type <- class(object)[1]

tune_param_list <- tunable(object)$name

# remove the non-tunable arguments as they are not important
adjustment_args <- object$arguments[tune_param_list]

res <- purrr::map_chr(adjustment_args, find_tune_id)
res <- ifelse(res == "", names(res), res)

tune_tbl(
name = names(res),
tunable = unname(!is.na(res)),
id = unname(res),
source = "tailor",
component = adjustment_type,
component_id = adjustment_type,
full = full
)
}

# for adjustments with no tunable parameters
no_param <-
tibble::tibble(
name = character(0),
Expand All @@ -25,6 +79,87 @@ no_param <-
component_id = character(0)
)

find_tune_id <- function(x, arg = caller_arg(x), call = caller_env()) {
if (length(x) == 0L) {
return(NA_character_)
}
if (rlang::is_quosures(x)) {
.x <- try(purrr::map(x, rlang::eval_tidy), silent = TRUE)
if (inherits(.x, "try-error")) {
x <- purrr::map(x, rlang::quo_get_expr)
} else {
x <- .x
}
}
id <- tune_id(x, call = call)

if (!is.na(id)) {
return(id)
}

if (is.atomic(x) | is.name(x) | length(x) == 1) {
return(NA_character_)
}

tunable_elems <- vector("character", length = length(x))
for (i in seq_along(x)) {
tunable_elems[i] <- find_tune_id(x[[i]], call = call)
}
tunable_elems <- tunable_elems[!is.na(tunable_elems)]

if (length(tunable_elems) == 0) {
tunable_elems <- NA_character_
}

if (sum(tunable_elems == "", na.rm = TRUE) > 1) {
offenders <- paste0(deparse(x), collapse = "")
cli::cli_abort(
c(
"Only one tunable value is currently allowed per argument.",
"{.arg {arg}} has {.code {offenders}}."
),
call = call
)
}

return(tunable_elems)
}

tune_id <- function(x, call = caller_env()) {
if (is.null(x)) {
return(NA_character_)
} else {
if (rlang::is_quosures(x)) {
.x <- try(purrr::map(x, rlang::eval_tidy), silent = TRUE)
if (inherits(.x, "try-error")) {
x <- purrr::map(x, rlang::quo_get_expr)
} else {
x <- .x
}
if (is.null(x)) {
return(NA_character_)
}
}

if (is.call(x)) {
if (rlang::call_name(x) == "tune") {
if (length(x) > 1) {
return(x[[2]])
} else {
return("")
}

return(x$id)
} else {
return(NA_character_)
}
}
}

NA_character_
}

# new_adjustment -------------------------------------------------------------
# These values are used to specify "what will we need for the adjustment?" and
# "what will we change?". For the outputs, we cannot change the probabilities
# without changing the classes. This is important because we are going to have
Expand Down
16 changes: 16 additions & 0 deletions tests/testthat/_snaps/extract.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# extract single parameter from tailor with no adjustments

Code
extract_parameter_dials(tailor(), parameter = "none there")
Condition
Error in `extract_parameter_dials()`:
! No parameter exists with id "none there".

# extract single parameter from tailor with no tunable parameters

Code
extract_parameter_dials(tlr, parameter = "none there")
Condition
Error in `extract_parameter_dials()`:
! No parameter exists with id "none there".

9 changes: 9 additions & 0 deletions tests/testthat/_snaps/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,12 @@
Error in `fit()`:
! Only factor and numeric outcomes are currently supported.

# find_tune_id() works

Code
find_tune_id(x)
Condition
Error:
! Only one tunable value is currently allowed per argument.
`x` has `list(a = tune(), b = tune())`.

Loading
Loading