diff --git a/NEWS.md b/NEWS.md index f5154a03..119647fe 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # workflows (development version) +* Added a standalone file `standalone-input-names.R` with APIs for returning the +names of the predictors in the original data given to `fit()`. + * Each of the `pull_*()` functions soft-deprecated in workflows v0.2.3 now warn on every usage. * `add_recipe()` will now error informatively when supplied a trained recipe (#179). diff --git a/R/standalone-input-names.R b/R/standalone-input-names.R new file mode 100644 index 00000000..71e33e75 --- /dev/null +++ b/R/standalone-input-names.R @@ -0,0 +1,81 @@ +# --- +# repo: tidymodels/workflows +# file: standalone-input-names.R +# last-updated: 2024-01-21 +# license: https://unlicense.org +# requires: cli, rlang +# --- + +# This file provides a portable set of helper functions for determining the +# names of the predictor columns used as inputs into a workflow. + +# ## Changelog +# 2024-01-21 +# * First version + +# nocov start + +# ------------------------------------------------------------------------------ +# Primary functions + +# @param x A _fitted_ workflow or recipe. +# @param call An environment indicating where the top-level function was invoked +# to print out better errors. +# @return A character vector of sorted columns names. + +.get_input_predictors_workflow <- function(x, ..., call = rlang::current_env()) { + check_workflow_fit(x, call = call) + # We can get the columns that are inputs to the recipe but some of these may + # not be predictors. We'll interrogate the recipe and pull out the current + # predictor names from the original input + if ("recipe" %in% names(x$pre$actions)) { + mold <- x$pre$mold + rec <- mold$blueprint$recipe + res <- .get_input_predictors_recipe(rec) + } else { + res <- blueprint_ptype(x) + } + sort(unique(res)) +} + +.get_input_predictors_recipe <- function(x, ..., call = rlang::current_env()) { + check_recipe_fit(x, call = call) + var_info <- x$last_term_info + + keep_rows <- var_info$source == "original" & is_predictor_role(var_info) + var_info <- var_info[keep_rows,] + var_info$variable +} + +.get_input_outcome_workflow <- function(x) { + check_workflow_fit(x) + names(x$pre$mold$blueprint$ptypes$outcomes) +} + +# ------------------------------------------------------------------------------ +# Helper functions + +check_workflow_fit <- function(x, call) { + if (!x$trained) { + cli::cli_abort("The workflow should be trained.", call = call) + } + invisible(NULL) +} + +check_recipe_fit <- function(x, call) { + is_trained <- vapply(x$steps, function(x) x$trained, logical(1)) + if (!all(is_trained)) { + cli::cli_abort("All recipe steps should be trained.", call = call) + } + invisible(NULL) +} + +blueprint_ptype <- function(x) { + names(x$pre$mold$blueprint$ptypes$predictors) +} + +is_predictor_role <- function(x) { + vapply(x$role, function(x) any(x == "predictor"), logical(1)) +} + +# nocov end diff --git a/tests/testthat/_snaps/standalone-input-names.md b/tests/testthat/_snaps/standalone-input-names.md new file mode 100644 index 00000000..dce747e1 --- /dev/null +++ b/tests/testthat/_snaps/standalone-input-names.md @@ -0,0 +1,32 @@ +# get recipe input column names + + Code + workflows:::.get_input_predictors_workflow(workflow) + Condition + Error in `workflows:::.get_input_predictors_workflow()`: + ! The workflow should be trained. + +--- + + Code + workflows:::.get_input_predictors_recipe(rec_with_id) + Condition + Error in `workflows:::.get_input_predictors_recipe()`: + ! All recipe steps should be trained. + +# get formula input column names + + Code + workflows:::.get_input_predictors_workflow(workflow) + Condition + Error in `workflows:::.get_input_predictors_workflow()`: + ! The workflow should be trained. + +# get predictor input column names + + Code + workflows:::.get_input_predictors_workflow(workflow) + Condition + Error in `workflows:::.get_input_predictors_workflow()`: + ! The workflow should be trained. + diff --git a/tests/testthat/test-standalone-input-names.R b/tests/testthat/test-standalone-input-names.R new file mode 100644 index 00000000..2fcc803f --- /dev/null +++ b/tests/testthat/test-standalone-input-names.R @@ -0,0 +1,88 @@ +test_that("get recipe input column names", { + skip_if_not_installed("modeldata") + skip_if_not_installed("recipes") + + library(recipes) + + data(cells, package = "modeldata") + + cells <- cells[, 1:10] + pred_names <- sort(names(cells)[3:10]) + + rec_with_id <- + recipes::recipe(class ~ ., cells) %>% + update_role(case, new_role = "destination") %>% + step_rm(angle_ch_1) %>% + step_pca(all_predictors()) + + workflow <- workflow() + workflow <- add_recipe(workflow, rec_with_id) + workflow <- add_model(workflow, parsnip::logistic_reg()) + workflow_fit <- fit(workflow, cells) + + expect_snapshot( + workflows:::.get_input_predictors_workflow(workflow), + error = TRUE + ) + expect_equal( + workflows:::.get_input_predictors_workflow(workflow_fit), + pred_names + ) + expect_snapshot( + workflows:::.get_input_predictors_recipe(rec_with_id), + error = TRUE + ) + +}) + +test_that("get formula input column names", { + skip_if_not_installed("modeldata") + + data(Chicago, package = "modeldata") + + Chicago <- Chicago[, c("ridership", "date", "Austin")] + pred_names <- sort(c("date", "Austin")) + + workflow <- workflow() + workflow <- add_formula(workflow, ridership ~ .) + workflow <- add_model(workflow, parsnip::linear_reg()) + workflow_fit <- fit(workflow, Chicago) + + expect_snapshot( + workflows:::.get_input_predictors_workflow(workflow), + error = TRUE + ) + expect_equal( + workflows:::.get_input_predictors_workflow(workflow_fit), + pred_names + ) + +}) + + +test_that("get predictor input column names", { + skip_if_not_installed("modeldata") + + data(Chicago, package = "modeldata") + + Chicago <- Chicago[, c("ridership", "date", "Austin")] + pred_names <- sort(c("date", "Austin")) + + workflow <- workflow() + workflow <- + add_variables(workflow, + outcomes = c(ridership), + predictors = c(tidyselect::everything())) + workflow <- add_model(workflow, parsnip::linear_reg()) + workflow_fit <- fit(workflow, Chicago) + + expect_snapshot( + workflows:::.get_input_predictors_workflow(workflow), + error = TRUE + ) + expect_equal( + workflows:::.get_input_predictors_workflow(workflow_fit), + pred_names + ) + +})