Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

resample calibration post-processors with an internal split #894

Merged
merged 19 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,15 @@ Suggests:
modeldata,
scales,
spelling,
tailor,
testthat (>= 3.0.0),
xgboost,
xml2
Remotes:
tidymodels/rsample,
tidymodels/tailor,
tidymodels/workflows,
tidymodels/hardhat
Config/Needs/website: pkgdown, tidymodels, kknn, doParallel, doFuture,
tidyverse/tidytemplate
Config/testthat/edition: 3
Expand Down
140 changes: 115 additions & 25 deletions R/grid_code_paths.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ tune_grid_loop <- function(resamples,
metrics,
control,
eval_time = NULL,
rng) {
rng,
rset_info) {
fn_tune_grid_loop <- tune_grid_loop_tune

if (workflow_uses_agua(workflow)) {
Expand All @@ -19,7 +20,8 @@ tune_grid_loop <- function(resamples,
metrics,
control,
eval_time,
rng
rng,
rset_info
)

# carry out arranging by id before extracting each element of results (#728)
Expand All @@ -43,7 +45,8 @@ tune_grid_loop_tune <- function(resamples,
metrics,
control,
eval_time = NULL,
rng) {
rng,
rset_info) {
n_resamples <- nrow(resamples)

parallel_over <- control$parallel_over
Expand All @@ -60,7 +63,8 @@ tune_grid_loop_tune <- function(resamples,
control = control,
eval_time = eval_time,
rng = rng,
parallel_over = parallel_over
parallel_over = parallel_over,
rset_info = rset_info
)
}

Expand Down Expand Up @@ -143,7 +147,8 @@ tune_grid_loop_impl <- function(fn_tune_grid_loop_iter,
control,
eval_time = NULL,
rng,
parallel_over) {
parallel_over,
rset_info) {
splits <- resamples$splits
packages <- c(control$pkgs, required_pkgs(workflow))
grid_info <- compute_grid_info(workflow, grid)
Expand Down Expand Up @@ -210,7 +215,8 @@ tune_grid_loop_impl <- function(fn_tune_grid_loop_iter,
eval_time = eval_time,
seed = seed,
metrics_info = metrics_info,
params = params
params = params,
rset_info = rset_info
)
}
)
Expand Down Expand Up @@ -282,7 +288,8 @@ tune_grid_loop_impl <- function(fn_tune_grid_loop_iter,
eval_time = eval_time,
seed = seed,
metrics_info = metrics_info,
params = params
params = params,
rset_info = rset_info
)
}
)
Expand Down Expand Up @@ -333,7 +340,12 @@ tune_grid_loop_iter <- function(split,
eval_time = NULL,
seed,
metrics_info = metrics_info(metrics),
params) {
params,
rset_info = NULL) {
# `split` may be overwritten later on to create an "internal" split for
# post-processing. however, we want the original split to persist so we can
# use it (particularly `labels(split_orig)`) in logging
split_orig <- split

load_pkgs(workflow)
.load_namespace(control$pkgs)
Expand Down Expand Up @@ -373,6 +385,30 @@ tune_grid_loop_iter <- function(split,

training <- rsample::analysis(split)
Copy link
Member

@topepo topepo Apr 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should change this name just on principle


if (workflows::should_inner_split(workflow)) {
# if the workflow has a postprocessor that needs training (i.e. calibration),
# further split the analysis data into an "internal" analysis and
# assessment set.
# * the preprocessor and model (excluding the post-processor) are fitted
# on `analysis(split_post)`, the internal analysis set
# * that model generates predictions on `assessment(split_post)`, the
# internal assessment set
# * the post-processor is trained on the predictions generated from the
# internal assessment set
# * the model (including the post-processor) generates predictions on the
# assessment set (not internal, i.e. `assessment(split)`) and those
# predictions are assessed with performance metrics
# todo: check if workflow's `method` is incompatible with `class(split)`?
# todo: workflow's `method` is currently ignored in favor of the one
# automatically dispatched to from `split`. consider this is combination
# with above todo.
split_args <- c(rset_info$att, list(prop = workflow$post$actions$tailor$prop))
split <- rsample::inner_split(split, split_args = split_args)
# todo: this should have a better name (analysis?) -- needs to be
# `training` right now to align with the `training` above
training <- rsample::analysis(split)
simonpcouch marked this conversation as resolved.
Show resolved Hide resolved
}

# ----------------------------------------------------------------------------
# Preprocessor loop

Expand Down Expand Up @@ -400,7 +436,7 @@ tune_grid_loop_iter <- function(split,
workflow <- .catch_and_log(
.expr = .fit_pre(workflow, training),
control,
split,
split_orig,
iter_msg_preprocessor,
notes = out_notes
)
Expand Down Expand Up @@ -435,7 +471,7 @@ tune_grid_loop_iter <- function(split,
workflow <- .catch_and_log_fit(
.expr = .fit_model(workflow, control_workflow),
control,
split,
split_orig,
iter_msg_model,
notes = out_notes
)
Expand All @@ -460,24 +496,26 @@ tune_grid_loop_iter <- function(split,
iter_grid_model
)

elt_extract <- .catch_and_log(
extract_details(workflow, control$extract),
control,
split,
paste(iter_msg_model, "(extracts)"),
bad_only = TRUE,
notes = out_notes
)
elt_extract <- make_extracts(elt_extract, iter_grid, split, .config = iter_config)
out_extracts <- append_extracts(out_extracts, elt_extract)
if (!workflows::should_inner_split(workflow)) {
elt_extract <- .catch_and_log(
extract_details(workflow, control$extract),
control,
split_orig,
paste(iter_msg_model, "(extracts)"),
bad_only = TRUE,
notes = out_notes
)
elt_extract <- make_extracts(elt_extract, iter_grid, split_orig, .config = iter_config)
out_extracts <- append_extracts(out_extracts, elt_extract)
}

iter_msg_predictions <- paste(iter_msg_model, "(predictions)")

iter_predictions <- .catch_and_log(
predict_model(split, workflow, iter_grid, metrics, iter_submodels,
metrics_info = metrics_info, eval_time = eval_time),
control,
split,
split_orig,
iter_msg_predictions,
bad_only = TRUE,
notes = out_notes
Expand All @@ -488,14 +526,64 @@ tune_grid_loop_iter <- function(split,
next
}

if (workflows::should_inner_split(workflow)) {
# note that, since we're training a postprocessor, `iter_predictions`
# are the predictions from the internal assessment set rather than the
# assessment set (i.e. `assessment(split_orig)`)

# train the post-processor on the predictions generated from the model
# on the internal assessment set
# todo: this is the same assessment set that `predict_model` makes.
# we're ad-hoc `augment()`ing here, but would be nice to just have
# those predictors
# todo: needs a `.catch_and_log`
# todo: .fit_post currently takes in `assessment(split)` rather than
# a set of predictions, meaning that we predict on `assessment(split)`
# twice :(
internal_assessment <- rsample::assessment(split)
workflow_with_post <-
.fit_post(workflow, dplyr::bind_cols(rsample::assessment(split)))

workflow_with_post <- .fit_finalize(workflow_with_post)

# run extract function on workflow with trained postprocessor
elt_extract <- .catch_and_log(
extract_details(workflow_with_post, control$extract),
control,
split_orig,
paste(iter_msg_model, "(extracts)"),
bad_only = TRUE,
notes = out_notes
)
elt_extract <- make_extracts(elt_extract, iter_grid, split_orig, .config = iter_config)
out_extracts <- append_extracts(out_extracts, elt_extract)


# generate predictions on the assessment set (not internal,
# i.e. `assessment(split_orig)`) from the model and apply the
# post-processor to those predictions to generate updated predictions
iter_predictions <- .catch_and_log(
predict_model(split_orig, workflow_with_post, iter_grid, metrics,
iter_submodels, metrics_info = metrics_info,
eval_time = eval_time),
control,
split_orig,
paste(iter_msg_model, "(predictions with post-processor)"),
bad_only = TRUE,
notes = out_notes
)

# now, assess those predictions with performance metrics
}

out_metrics <- append_metrics(
collection = out_metrics,
predictions = iter_predictions,
metrics = metrics,
param_names = param_names,
outcome_name = outcome_names,
event_level = event_level,
split = split,
split = split_orig,
.config = iter_config,
metrics_info = metrics_info
)
Expand All @@ -505,7 +593,7 @@ tune_grid_loop_iter <- function(split,
out_predictions <- append_predictions(
collection = out_predictions,
predictions = iter_predictions,
split = split,
split = split_orig,
control = control,
.config = iter_config_metrics
)
Expand All @@ -532,7 +620,8 @@ tune_grid_loop_iter_safely <- function(fn_tune_grid_loop_iter,
eval_time = NULL,
seed,
metrics_info,
params) {
params,
rset_info) {

fn_tune_grid_loop_iter_wrapper <- super_safely(fn_tune_grid_loop_iter)

Expand All @@ -546,7 +635,8 @@ tune_grid_loop_iter_safely <- function(fn_tune_grid_loop_iter,
eval_time,
seed,
metrics_info = metrics_info,
params
params,
rset_info
)

error <- result$error
Expand Down
12 changes: 12 additions & 0 deletions R/grid_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ predict_model <- function(split, workflow, grid, metrics, submodels = NULL,
y_vals$.row <- orig_rows
res <- dplyr::full_join(res, y_vals, by = ".row")

if (has_postprocessor(workflow)) {
post <- extract_postprocessor(workflow)

if (tailor::tailor_fully_trained(post)) {
res <- predict(post, res)
}
}

# Add implicitly grouped metric data, if applicable
metrics_by <- get_metrics_by(metrics)
if (has_metrics_by(metrics_by)) {
Expand Down Expand Up @@ -628,6 +636,10 @@ has_preprocessor_variables <- function(workflow) {
"variables" %in% names(workflow$pre$actions)
}

has_postprocessor <- function(workflow) {
"tailor" %in% names(workflow$post$actions)
}

has_case_weights <- function(workflow) {
"case_weights" %in% names(workflow$pre$actions)
}
Expand Down
3 changes: 2 additions & 1 deletion R/tune_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ tune_grid_workflow <- function(workflow,
metrics = metrics,
eval_time = eval_time,
control = control,
rng = rng
rng = rng,
rset_info = rset_info
)

if (is_cataclysmic(resamples)) {
Expand Down
52 changes: 52 additions & 0 deletions tests/testthat/test-resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,58 @@ test_that("extracted workflow is finalized", {
expect_true(result_workflow$trained)
})

test_that("can use `fit_resamples()` with a workflow - postprocessor (requires training)", {
skip_if_not_installed("tailor")

y <- seq(0, 7, .001)
dat <- data.frame(y = y, x = y + (y-3)^2)

dat

folds <- rsample::vfold_cv(dat, v = 2)

wflow <-
workflows::workflow(
y ~ x,
parsnip::linear_reg()
) %>%
workflows::add_tailor(
tailor::tailor("regression") %>% tailor::adjust_numeric_calibration("linear"),
prop = 2/3,
method = class(folds$splits[[1]])
)

set.seed(1)
tune_res <-
fit_resamples(
wflow,
folds,
control = control_resamples(save_pred = TRUE, extract = identity)
)

tune_preds <-
collect_predictions(tune_res) %>%
dplyr::filter(id == "Fold1")

tune_wflow <-
collect_extracts(tune_res) %>%
pull(.extracts) %>%
`[[`(1)

# mock `tune::tune_grid_loop_iter`'s RNG scheme
set.seed(1)
seed <- generate_seeds(TRUE, 1)[[1]]
old_kind <- RNGkind()[[1]]
assign(".Random.seed", seed, envir = globalenv())

wflow_res <- generics::fit(wflow, rsample::analysis(folds$splits[[1]]))
wflow_preds <- predict(wflow_res, rsample::assessment(folds$splits[[1]]))

tune_wflow$fit$fit$elapsed$elapsed <- wflow_res$fit$fit$elapsed$elapsed
expect_equal(tune_preds$.pred, wflow_preds$.pred)
expect_equal(tune_wflow, wflow_res)
})

# Error capture ----------------------------------------------------------------

test_that("failure in recipe is caught elegantly", {
Expand Down
Loading