Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

APIs for determining original predictor columns #215

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# workflows (development version)

* Added a standalone file `standalone-input-names.R` with APIs for returning the
names of the predictors in the original data given to `fit()`.

* Each of the `pull_*()` functions soft-deprecated in workflows v0.2.3 now warn on every usage.

* `add_recipe()` will now error informatively when supplied a trained recipe (#179).
Expand Down
83 changes: 83 additions & 0 deletions R/standalone-input-names.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# ---
# repo: tidymodels/workflows
# file: standalone-input-names.R
# last-updated: 2024-01-291
# license: https://unlicense.org
# requires: cli, rlang
# ---

# This file provides a portable set of helper functions for determining the
# names of the predictor columns used as inputs into a workflow.

# ## Changelog
# 2024-01-21
# * First version
# 2024-01-29
topepo marked this conversation as resolved.
Show resolved Hide resolved
# * Changes after PR review

# nocov start

# ------------------------------------------------------------------------------
# Primary functions

# @param x A _fitted_ workflow or recipe.
# @param call An environment indicating where the top-level function was invoked
# to print out better errors.
# @return A character vector of sorted columns names.

.get_input_predictors_workflow <- function(x, ..., call = rlang::current_env()) {
check_workflow_fit(x, call = call)
# We can get the columns that are inputs to the recipe but some of these may
# not be predictors. We'll interrogate the recipe and pull out the current
# predictor names from the original input
if ("recipe" %in% names(x$pre$actions)) {
mold <- x$pre$mold
rec <- mold$blueprint$recipe
res <- .get_input_predictors_recipe(rec)
} else {
res <- blueprint_ptype(x)
}
sort(unique(res))
}

.get_input_predictors_recipe <- function(x, ..., call = rlang::current_env()) {
check_recipe_fit(x, call = call)
var_info <- x$last_term_info

keep_rows <- var_info$source == "original" & is_predictor_role(var_info)
var_info <- var_info[keep_rows,]
var_info$variable
}

.get_input_outcome_workflow <- function(x) {
check_workflow_fit(x)
names(x$pre$mold$blueprint$ptypes$outcomes)
}

# ------------------------------------------------------------------------------
# Helper functions

check_workflow_fit <- function(x, call) {
if (!x$trained) {
cli::cli_abort("The workflow should be trained.", call = call)
}
invisible(NULL)
}

check_recipe_fit <- function(x, call) {
is_trained <- vapply(x$steps, function(x) x$trained, logical(1))
if (!all(is_trained)) {
cli::cli_abort("All recipe steps should be trained.", call = call)
}
invisible(NULL)
}

blueprint_ptype <- function(x) {
names(x$pre$mold$blueprint$ptypes$predictors)
}

is_predictor_role <- function(x) {
vapply(x$role, function(x) any(x == "predictor"), logical(1))
}

# nocov end
32 changes: 32 additions & 0 deletions tests/testthat/_snaps/standalone-input-names.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# get recipe input column names

Code
workflows:::.get_input_predictors_workflow(workflow)
Condition
Error in `workflows:::.get_input_predictors_workflow()`:
! The workflow should be trained.

---

Code
workflows:::.get_input_predictors_recipe(rec_with_id)
Condition
Error in `workflows:::.get_input_predictors_recipe()`:
! All recipe steps should be trained.

# get formula input column names

Code
workflows:::.get_input_predictors_workflow(workflow)
Condition
Error in `workflows:::.get_input_predictors_workflow()`:
! The workflow should be trained.

# get predictor input column names

Code
workflows:::.get_input_predictors_workflow(workflow)
Condition
Error in `workflows:::.get_input_predictors_workflow()`:
! The workflow should be trained.

88 changes: 88 additions & 0 deletions tests/testthat/test-standalone-input-names.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
test_that("get recipe input column names", {
skip_if_not_installed("modeldata")
skip_if_not_installed("recipes")

library(recipes)

data(cells, package = "modeldata")

cells <- cells[, 1:10]
pred_names <- sort(names(cells)[3:10])

rec_with_id <-
recipes::recipe(class ~ ., cells) %>%
update_role(case, new_role = "destination") %>%
step_rm(angle_ch_1) %>%
step_pca(all_predictors())

workflow <- workflow()
workflow <- add_recipe(workflow, rec_with_id)
workflow <- add_model(workflow, parsnip::logistic_reg())
workflow_fit <- fit(workflow, cells)

expect_snapshot(
workflows:::.get_input_predictors_workflow(workflow),
error = TRUE
)
expect_equal(
workflows:::.get_input_predictors_workflow(workflow_fit),
pred_names
)
expect_snapshot(
workflows:::.get_input_predictors_recipe(rec_with_id),
error = TRUE
)

})

test_that("get formula input column names", {
skip_if_not_installed("modeldata")

data(Chicago, package = "modeldata")

Chicago <- Chicago[, c("ridership", "date", "Austin")]
pred_names <- sort(c("date", "Austin"))

workflow <- workflow()
workflow <- add_formula(workflow, ridership ~ .)
workflow <- add_model(workflow, parsnip::linear_reg())
workflow_fit <- fit(workflow, Chicago)

expect_snapshot(
workflows:::.get_input_predictors_workflow(workflow),
error = TRUE
)
expect_equal(
workflows:::.get_input_predictors_workflow(workflow_fit),
pred_names
)

})


test_that("get predictor input column names", {
skip_if_not_installed("modeldata")

data(Chicago, package = "modeldata")

Chicago <- Chicago[, c("ridership", "date", "Austin")]
pred_names <- sort(c("date", "Austin"))

workflow <- workflow()
workflow <-
add_variables(workflow,
outcomes = c(ridership),
predictors = c(tidyselect::everything()))
workflow <- add_model(workflow, parsnip::linear_reg())
workflow_fit <- fit(workflow, Chicago)

expect_snapshot(
workflows:::.get_input_predictors_workflow(workflow),
error = TRUE
)
expect_equal(
workflows:::.get_input_predictors_workflow(workflow_fit),
pred_names
)

})
Loading