diff --git a/R/clustering.R b/R/clustering.R index 5ed07f60..dbfa2715 100644 --- a/R/clustering.R +++ b/R/clustering.R @@ -57,7 +57,7 @@ clustering_cv <- function(data, distance_function = "dist", cluster_function = c("kmeans", "hclust"), ...) { - check_repeats(repeats) + check_number_whole(repeats, min = 1) if (!rlang::is_function(cluster_function)) { cluster_function <- rlang::arg_match(cluster_function) diff --git a/R/vfold.R b/R/vfold.R index faded51e..0aa9f0c9 100644 --- a/R/vfold.R +++ b/R/vfold.R @@ -72,7 +72,7 @@ vfold_cv <- function(data, v = 10, repeats = 1, } check_strata(strata, data) - check_repeats(repeats) + check_number_whole(repeats, min = 1) if (repeats == 1) { split_objs <- vfold_splits( @@ -213,7 +213,7 @@ vfold_splits <- function(data, v = 10, strata = NULL, breaks = 4, pool = 0.1, pr #' @export group_vfold_cv <- function(data, group = NULL, v = NULL, repeats = 1, balance = c("groups", "observations"), ..., strata = NULL, pool = 0.1) { check_dots_empty() - check_repeats(repeats) + check_number_whole(repeats, min = 1) group <- validate_group({{ group }}, data) balance <- rlang::arg_match(balance) @@ -331,23 +331,24 @@ add_vfolds <- function(x, v) { } check_v <- function(v, max_v, rows = "rows", prevent_loo = TRUE, call = rlang::caller_env()) { - if (!is.numeric(v) || length(v) != 1 || v < 2) { - cli_abort("{.arg v} must be a single positive integer greater than 1.", call = call) - } else if (v > max_v) { + check_number_whole(v, min = 2, call = call) + + if (v > max_v) { cli_abort( "The number of {rows} is less than {.arg v} = {.val {v}}.", call = call ) - } else if (prevent_loo && isTRUE(v == max_v)) { + } + if (prevent_loo && isTRUE(v == max_v)) { cli_abort(c( "Leave-one-out cross-validation is not supported by this function.", - "x" = "You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation.", - "i" = "Use `loo_cv()` in this case." + "x" = "You set {.arg v} to {.code nrow(data)}, which would result in a leave-one-out cross-validation.", + "i" = "Use {.fn loo_cv} in this case." ), call = call) } } -check_grouped_strata <- function(group, strata, pool, data) { +check_grouped_strata <- function(group, strata, pool, data, call = caller_env()) { strata <- tidyselect::vars_select(names(data), !!enquo(strata)) @@ -363,14 +364,11 @@ check_grouped_strata <- function(group, strata, pool, data) { if (nrow(vctrs::vec_unique(grouped_table)) != nrow(vctrs::vec_unique(grouped_table["group"]))) { - cli_abort("{.arg strata} must be constant across all members of each {.arg group}.") + cli_abort( + "{.field strata} must be constant across all members of each {.field group}.", + call = call + ) } strata } - -check_repeats <- function(repeats, call = rlang::caller_env()) { - if (!is.numeric(repeats) || length(repeats) != 1 || repeats < 1) { - cli_abort("{.arg repeats} must be a single positive integer.", call = call) - } -} diff --git a/tests/testthat/_snaps/clustering.md b/tests/testthat/_snaps/clustering.md index 048a87c1..68c20d93 100644 --- a/tests/testthat/_snaps/clustering.md +++ b/tests/testthat/_snaps/clustering.md @@ -12,7 +12,7 @@ clustering_cv(iris, Sepal.Length, v = -500) Condition Error in `clustering_cv()`: - ! `v` must be a single positive integer greater than 1. + ! `v` must be a whole number larger than or equal to 2, not the number -500. --- @@ -36,7 +36,7 @@ clustering_cv(Orange, v = 1, vars = "Tree") Condition Error in `clustering_cv()`: - ! `v` must be a single positive integer greater than 1. + ! `v` must be a whole number larger than or equal to 2, not the number 1. --- @@ -44,7 +44,7 @@ clustering_cv(Orange, repeats = 0) Condition Error in `clustering_cv()`: - ! `repeats` must be a single positive integer. + ! `repeats` must be a whole number larger than or equal to 1, not the number 0. --- @@ -52,7 +52,7 @@ clustering_cv(Orange, repeats = NULL) Condition Error in `clustering_cv()`: - ! `repeats` must be a single positive integer. + ! `repeats` must be a whole number, not `NULL`. --- diff --git a/tests/testthat/_snaps/vfold.md b/tests/testthat/_snaps/vfold.md index add53300..63664707 100644 --- a/tests/testthat/_snaps/vfold.md +++ b/tests/testthat/_snaps/vfold.md @@ -41,13 +41,13 @@ ! strata cannot be a object. i Use the time or event variable directly. -# bad args +# v arg is checked Code vfold_cv(iris, v = -500) Condition Error in `vfold_cv()`: - ! `v` must be a single positive integer greater than 1. + ! `v` must be a whole number larger than or equal to 2, not the number -500. --- @@ -55,7 +55,7 @@ vfold_cv(iris, v = 1) Condition Error in `vfold_cv()`: - ! `v` must be a single positive integer greater than 1. + ! `v` must be a whole number larger than or equal to 2, not the number 1. --- @@ -63,7 +63,7 @@ vfold_cv(iris, v = NULL) Condition Error in `vfold_cv()`: - ! `v` must be a single positive integer greater than 1. + ! `v` must be a whole number, not `NULL`. --- @@ -76,36 +76,36 @@ --- Code - vfold_cv(iris, v = 150, repeats = 2) + vfold_cv(mtcars, v = nrow(mtcars)) Condition Error in `vfold_cv()`: - ! Repeated resampling when `v` is 150 would create identical resamples. + ! Leave-one-out cross-validation is not supported by this function. + x You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation. + i Use `loo_cv()` in this case. ---- +# repeats arg is checked Code - vfold_cv(Orange, repeats = 0) + vfold_cv(iris, v = 150, repeats = 2) Condition Error in `vfold_cv()`: - ! `repeats` must be a single positive integer. + ! Repeated resampling when `v` is 150 would create identical resamples. --- Code - vfold_cv(Orange, repeats = NULL) + vfold_cv(Orange, repeats = 0) Condition Error in `vfold_cv()`: - ! `repeats` must be a single positive integer. + ! `repeats` must be a whole number larger than or equal to 1, not the number 0. --- Code - vfold_cv(mtcars, v = nrow(mtcars)) + vfold_cv(Orange, repeats = NULL) Condition Error in `vfold_cv()`: - ! Leave-one-out cross-validation is not supported by this function. - x You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation. - i Use `loo_cv()` in this case. + ! `repeats` must be a whole number, not `NULL`. # printing @@ -191,7 +191,7 @@ group_vfold_cv(Orange, v = 1, group = "Tree") Condition Error in `group_vfold_cv()`: - ! `v` must be a single positive integer greater than 1. + ! `v` must be a whole number larger than or equal to 2, not the number 1. # grouping -- other balance methods @@ -286,6 +286,14 @@ 10 Resample10 # i 20 more rows +# grouping fails for strata not constant across group members + + Code + group_vfold_cv(sample_data, group, v = 5, strata = outcome) + Condition + Error in `group_vfold_cv()`: + ! strata must be constant across all members of each group. + # grouping -- printing Code diff --git a/tests/testthat/test-vfold.R b/tests/testthat/test-vfold.R index f6e7bd1f..56602300 100644 --- a/tests/testthat/test-vfold.R +++ b/tests/testthat/test-vfold.R @@ -104,7 +104,7 @@ test_that("strata arg is checked", { }) }) -test_that("bad args", { +test_that("v arg is checked", { expect_snapshot(error = TRUE, { vfold_cv(iris, v = -500) }) @@ -117,6 +117,12 @@ test_that("bad args", { expect_snapshot(error = TRUE, { vfold_cv(iris, v = 500) }) + expect_snapshot(error = TRUE, { + vfold_cv(mtcars, v = nrow(mtcars)) + }) +}) + +test_that("repeats arg is checked", { expect_snapshot(error = TRUE, { vfold_cv(iris, v = 150, repeats = 2) }) @@ -126,9 +132,6 @@ test_that("bad args", { expect_snapshot(error = TRUE, { vfold_cv(Orange, repeats = NULL) }) - expect_snapshot(error = TRUE, { - vfold_cv(mtcars, v = nrow(mtcars)) - }) }) test_that("printing", { @@ -403,6 +406,35 @@ test_that("grouping -- strata", { ) }) +test_that("grouping fails for strata not constant across group members", { + set.seed(11) + + n_common_class <- 70 + n_rare_class <- 30 + + group_table <- tibble( + group = 1:100, + outcome = sample(c(rep(0, n_common_class), rep(1, n_rare_class))) + ) + observation_table <- tibble( + group = sample(1:100, 1e5, replace = TRUE), + observation = 1:1e5 + ) + sample_data <- dplyr::full_join( + group_table, + observation_table, + by = "group", + multiple = "all" + ) + + # violate requirement + sample_data$outcome[1] <- ifelse(sample_data$outcome[1], 0, 1) + + expect_snapshot(error = TRUE, { + group_vfold_cv(sample_data, group, v = 5, strata = outcome) + }) +}) + test_that("grouping -- repeated", { set.seed(11) rs2 <- group_vfold_cv(dat1, c, v = 3, repeats = 4)