From 1702751241f0ead7a9f562c421bdc6a56591d608 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Tue, 10 Dec 2024 14:11:36 -0600 Subject: [PATCH] allow regression `adjust_predictions_custom()` without `type` (closes #61) --- R/tailor.R | 4 ++-- .../testthat/test-adjust-predictions-custom.R | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/R/tailor.R b/R/tailor.R index 7678acd..5274f79 100644 --- a/R/tailor.R +++ b/R/tailor.R @@ -164,8 +164,8 @@ fit.tailor <- function(object, .data, outcome, estimate, probabilities = c(), columns$estimate <- names(tidyselect::eval_select(enquo(estimate), .data)) check_selection(enquo(estimate), columns$estimate, "estimate") columns$probabilities <- names(tidyselect::eval_select(enquo(probabilities), .data)) - if (any(c("probability", "everything") %in% - purrr::map_chr(object$adjustments, purrr::pluck, "inputs"))) { + if ("probability" %in% + purrr::map_chr(object$adjustments, purrr::pluck, "inputs")) { check_selection(enquo(probabilities), columns$probabilities, "probabilities") for (col in columns$probabilities) { check_variable_type(.data[[col]], "probability", "probabilities") diff --git a/tests/testthat/test-adjust-predictions-custom.R b/tests/testthat/test-adjust-predictions-custom.R index 8643f32..c346696 100644 --- a/tests/testthat/test-adjust-predictions-custom.R +++ b/tests/testthat/test-adjust-predictions-custom.R @@ -40,6 +40,24 @@ test_that("basic adjust_predictions_custom() usage works", { ) }) +test_that("adjust_predictions_custom() for numerics works without setting type (#61)", { + library(tibble) + + set.seed(1) + d_calibration <- tibble(y = rnorm(100), y_pred = y/2 + rnorm(100)) + d_test <- tibble(y = rnorm(100), y_pred = y/2 + rnorm(100)) + + expect_no_error({ + tlr <- + tailor() %>% + adjust_numeric_calibration() %>% + adjust_numeric_range(lower_limit = 2) %>% + adjust_predictions_custom(squared = y_pred^2) + + tlr_fit <- fit(tlr, d_calibration, outcome = y, estimate = y_pred) + }) +}) + test_that("adjustment printing", { expect_snapshot(tailor() %>% adjust_predictions_custom()) })