diff --git a/NAMESPACE b/NAMESPACE index 19772f54..d01b9f3f 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -25,6 +25,12 @@ S3method(dim,initial_validation_split) S3method(dim,rsplit) S3method(get_rsplit,default) S3method(get_rsplit,rset) +S3method(inner_split,apparent_split) +S3method(inner_split,clustering_split) +S3method(inner_split,group_mc_split) +S3method(inner_split,group_vfold_split) +S3method(inner_split,mc_split) +S3method(inner_split,vfold_split) S3method(int_bca,bootstraps) S3method(int_pctl,bootstraps) S3method(int_t,bootstraps) @@ -343,6 +349,7 @@ S3method(vec_restore,validation_split) S3method(vec_restore,validation_time_split) S3method(vec_restore,vfold_cv) export(.get_fingerprint) +export(.get_split_args) export(add_resample_id) export(all_of) export(analysis) @@ -368,6 +375,7 @@ export(initial_split) export(initial_time_split) export(initial_validation_split) export(initial_validation_time_split) +export(inner_split) export(int_bca) export(int_pctl) export(int_t) diff --git a/NEWS.md b/NEWS.md index 9404b5e8..75785b90 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # rsample (development version) +* The new `inner_split()` function and its methods for various resamples are for usage in tune to create an inner resample of the analysis set to fit the preprocessor and model on one part and the post-processor on the other part (#483). + ## Bug fixes * `vfold_cv()` now utilizes the `breaks` argument correctly for repeated cross-validation (@ZWael, #471). diff --git a/R/inner_split.R b/R/inner_split.R new file mode 100644 index 00000000..6cd442df --- /dev/null +++ b/R/inner_split.R @@ -0,0 +1,159 @@ +#' Inner split of the analysis set for fitting a post-processor +#' +#' @param x An `rsplit` object. +#' @param split_args A list of arguments to be used for the inner split. +#' @param ... Not currently used. +#' @return An `rsplit` object. +#' @details +#' `rsplit` objects live most commonly inside of an `rset` object. 
The +#' `split_args` argument can be the output of [.get_split_args()] on that +#' corresponding `rset` object, even if some of the arguments used to create the +#' `rset` object are not needed for the inner split. +#' * For `mc_split` and `group_mc_split` objects, `inner_split()` will ignore +#' `split_args$times`. +#' * For `vfold_split` and `group_vfold_split` objects, it will ignore +#' `split_args$times` and `split_args$repeats`. `split_args$v` will be used to +#' set `split_args$prop` to `1 - 1/v` if `prop` is not already set and otherwise +#' ignored. The method +#' for `group_vfold_split` will always use `split_args$balance = NULL`. +#' * For `clustering_split` objects, it will ignore `split_args$repeats`. +#' +#' @keywords internal +#' @export +inner_split <- function(x, ...) { + UseMethod("inner_split") +} + +# mc --------------------------------------------------------------------- + +#' @rdname inner_split +#' @export +inner_split.mc_split <- function(x, split_args, ...) { + check_dots_empty() + + analysis_set <- analysis(x) + + split_args$times <- 1 + split_inner <- rlang::inject( + mc_splits(analysis_set, !!!split_args) + ) + split_inner <- split_inner$splits[[1]] + + class_inner <- paste0(class(x)[1], "_inner") + split_inner <- add_class(split_inner, class_inner) + split_inner +} + +#' @rdname inner_split +#' @export +inner_split.group_mc_split <- function(x, split_args, ...) { + check_dots_empty() + + analysis_set <- analysis(x) + + split_args$times <- 1 + split_inner <- rlang::inject( + group_mc_splits(analysis_set, !!!split_args) + ) + split_inner <- split_inner$splits[[1]] + + class_inner <- paste0(class(x)[1], "_inner") + split_inner <- add_class(split_inner, class_inner) + split_inner +} + + +# vfold ------------------------------------------------------------------ + +#' @rdname inner_split +#' @export +inner_split.vfold_split <- function(x, split_args, ...) 
{ + check_dots_empty() + + analysis_set <- analysis(x) + + # TODO should this be done outside of rsample, + # in workflows or tune? + if (is.null(split_args$prop)) { + split_args$prop <- 1 - 1/split_args$v + } + # use mc_splits for a random split + split_args$times <- 1 + split_args$v <- NULL + split_args$repeats <- NULL + split_inner <- rlang::inject( + mc_splits(analysis_set, !!!split_args) + ) + split_inner <- split_inner$splits[[1]] + + class_inner <- paste0(class(x)[1], "_inner") + class(split_inner) <- c(class_inner, class(x)) + split_inner +} + +#' @rdname inner_split +#' @export +inner_split.group_vfold_split <- function(x, split_args, ...) { + check_dots_empty() + + analysis_set <- analysis(x) + + # TODO should this be done outside of rsample, + # in workflows or tune? + if (is.null(split_args$prop)) { + split_args$prop <- 1 - 1/split_args$v + } + + # use group_mc_splits for a random split + split_args$times <- 1 + split_args$v <- NULL + split_args$repeats <- NULL + split_args$balance <- NULL + split_inner <- rlang::inject( + group_mc_splits(analysis_set, !!!split_args) + ) + split_inner <- split_inner$splits[[1]] + + class_inner <- paste0(class(x)[1], "_inner") + class(split_inner) <- c(class_inner, class(x)) + split_inner +} + +# clustering ------------------------------------------------------------- + +#' @rdname inner_split +#' @export +inner_split.clustering_split <- function(x, split_args, ...) { + check_dots_empty() + + analysis_set <- analysis(x) + + # TODO: reduce the number of clusters by 1 in tune? + split_args$repeats <- 1 + split_inner <- rlang::inject( + clustering_cv(analysis_set, !!!split_args) + ) + split_inner <- split_inner$splits[[1]] + + class_inner <- paste0(class(x)[1], "_inner") + class(split_inner) <- c(class_inner, class(x)) + split_inner +} + + +# apparent --------------------------------------------------------------- + +#' @rdname inner_split +#' @export +inner_split.apparent_split <- function(x, ...) 
{ + check_dots_empty() + + analysis_set <- analysis(x) + + split_inner <- apparent(analysis_set) + split_inner <- split_inner$splits[[1]] + + class_inner <- paste0(class(x)[1], "_inner") + class(split_inner) <- c(class_inner, class(x)) + split_inner +} diff --git a/R/misc.R b/R/misc.R index e1dc44e2..3b1aa7bf 100644 --- a/R/misc.R +++ b/R/misc.R @@ -271,13 +271,9 @@ reshuffle_rset <- function(rset) { } } - arguments <- attributes(rset) - useful_arguments <- names(formals(arguments$class[[1]])) - useful_arguments <- arguments[useful_arguments] - useful_arguments <- useful_arguments[!is.na(names(useful_arguments))] - if (identical(useful_arguments$strata, FALSE)) { - useful_arguments$strata <- NULL - } else if (identical(useful_arguments$strata, TRUE)) { + rset_type <- class(rset)[[1]] + split_arguments <- .get_split_args(rset) + if (identical(split_arguments$strata, TRUE)) { rlang::abort( "Cannot reshuffle this rset (`attr(rset, 'strata')` is `TRUE`, not a column identifier)", i = "If the original object was created with an older version of rsample, try recreating it with the newest version of the package" @@ -285,8 +281,8 @@ reshuffle_rset <- function(rset) { } do.call( - arguments$class[[1]], - c(list(data = rset$splits[[1]]$data), useful_arguments) + rset_type, + c(list(data = rset$splits[[1]]$data), split_arguments) ) } @@ -299,6 +295,22 @@ non_random_classes <- c( "validation_set" ) +#' Get the split arguments from an rset +#' @param rset An `rset` object. +#' @return A list of arguments used to create the rset. 
+#' @keywords internal +#' @export +.get_split_args <- function(rset) { + all_attributes <- attributes(rset) + args <- names(formals(all_attributes$class[[1]])) + split_args <- all_attributes[args] + split_args <- split_args[!is.na(names(split_args))] + if (identical(split_args$strata, FALSE)) { + split_args$strata <- NULL + } + split_args +} + #' Retrieve individual rsplits objects from an rset #' #' @param x The `rset` object to retrieve an rsplit from. diff --git a/man/dot-get_split_args.Rd b/man/dot-get_split_args.Rd new file mode 100644 index 00000000..c4bccaaf --- /dev/null +++ b/man/dot-get_split_args.Rd @@ -0,0 +1,18 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/misc.R +\name{.get_split_args} +\alias{.get_split_args} +\title{Get the split arguments from an rset} +\usage{ +.get_split_args(rset) +} +\arguments{ +\item{rset}{An \code{rset} object.} +} +\value{ +A list of arguments used to create the rset. +} +\description{ +Get the split arguments from an rset +} +\keyword{internal} diff --git a/man/inner_split.Rd b/man/inner_split.Rd new file mode 100644 index 00000000..b4114927 --- /dev/null +++ b/man/inner_split.Rd @@ -0,0 +1,56 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/inner_split.R +\name{inner_split} +\alias{inner_split} +\alias{inner_split.mc_split} +\alias{inner_split.group_mc_split} +\alias{inner_split.vfold_split} +\alias{inner_split.group_vfold_split} +\alias{inner_split.clustering_split} +\alias{inner_split.apparent_split} +\title{Inner split of the analysis set for fitting a post-processor} +\usage{ +inner_split(x, ...) + +\method{inner_split}{mc_split}(x, split_args, ...) + +\method{inner_split}{group_mc_split}(x, split_args, ...) + +\method{inner_split}{vfold_split}(x, split_args, ...) + +\method{inner_split}{group_vfold_split}(x, split_args, ...) + +\method{inner_split}{clustering_split}(x, split_args, ...) + +\method{inner_split}{apparent_split}(x, ...) 
+} +\arguments{ +\item{x}{An \code{rsplit} object.} + +\item{...}{Not currently used.} + +\item{split_args}{A list of arguments to be used for the inner split.} +} +\value{ +An \code{rsplit} object. +} +\description{ +Inner split of the analysis set for fitting a post-processor +} +\details{ +\code{rsplit} objects live most commonly inside of an \code{rset} object. The +\code{split_args} argument can be the output of \code{\link[=.get_split_args]{.get_split_args()}} on that +corresponding \code{rset} object, even if some of the arguments used to create the +\code{rset} object are not needed for the inner split. +\itemize{ +\item For \code{mc_split} and \code{group_mc_split} objects, \code{inner_split()} will ignore +\code{split_args$times}. +\item For \code{vfold_split} and \code{group_vfold_split} objects, it will ignore +\code{split_args$times} and \code{split_args$repeats}. \code{split_args$v} will be used to +set \code{split_args$prop} to \code{1 - 1/v} if \code{prop} is not already set and otherwise +ignored. The method +for \code{group_vfold_split} will always use \code{split_args$balance = NULL}. +\item For \code{clustering_split} objects, it will ignore \code{split_args$repeats}. 
+} +} +\keyword{internal} diff --git a/tests/testthat/test-inner_split.R b/tests/testthat/test-inner_split.R new file mode 100644 index 00000000..61153f2c --- /dev/null +++ b/tests/testthat/test-inner_split.R @@ -0,0 +1,165 @@ + +# mc --------------------------------------------------------------------- + +test_that("mc_split", { + set.seed(11) + r_set <- mc_cv(warpbreaks) + split_args <- .get_split_args(r_set) + r_split <- get_rsplit(r_set, 1) + + isplit <- inner_split(r_split, split_args) + + expect_identical( + isplit$data, + analysis(r_split) + ) + + expect_identical( + analysis(isplit), + isplit$data[isplit$in_id, ], + ignore_attr = "row.names" + ) + expect_identical( + assessment(isplit), + isplit$data[isplit$out_id, ], + ignore_attr = "row.names" + ) +}) + +test_that("group_mc_split", { + skip_if_not_installed("modeldata") + + data(ames, package = "modeldata", envir = rlang::current_env()) + + set.seed(11) + r_set <- group_mc_cv(ames, "MS_SubClass") + split_args <- .get_split_args(r_set) + r_split <- get_rsplit(r_set, 1) + + isplit <- inner_split(r_split, split_args) + + expect_identical( + isplit$data, + analysis(r_split) + ) + + expect_identical( + analysis(isplit), + isplit$data[isplit$in_id, ], + ignore_attr = "row.names" + ) + expect_identical( + assessment(isplit), + isplit$data[isplit$out_id, ], + ignore_attr = "row.names" + ) +}) + + +# vfold ------------------------------------------------------------------ + +test_that("vfold_split", { + set.seed(11) + r_set <- vfold_cv(warpbreaks, v = 5) + split_args <- .get_split_args(r_set) + r_split <- get_rsplit(r_set, 1) + + isplit <- inner_split(r_split, split_args) + + expect_identical( + isplit$data, + analysis(r_split) + ) + + expect_identical( + analysis(isplit), + isplit$data[isplit$in_id, ], + ignore_attr = "row.names" + ) + expect_identical( + assessment(isplit), + isplit$data[isplit$out_id, ], + ignore_attr = "row.names" + ) +}) + +test_that("group_vfold_split", { + skip_if_not_installed("modeldata") 
+ + data(ames, package = "modeldata", envir = rlang::current_env()) + + set.seed(11) + r_set <- group_vfold_cv(ames, "MS_SubClass") + split_args <- .get_split_args(r_set) + r_split <- get_rsplit(r_set, 1) + + isplit <- inner_split(r_split, split_args) + + expect_identical( + isplit$data, + analysis(r_split) + ) + + expect_identical( + analysis(isplit), + isplit$data[isplit$in_id, ], + ignore_attr = "row.names" + ) + expect_identical( + assessment(isplit), + isplit$data[isplit$out_id, ], + ignore_attr = "row.names" + ) +}) + + +# clustering ------------------------------------------------------------- + +test_that("clustering_split", { + set.seed(11) + r_set <- clustering_cv(warpbreaks, vars = breaks, v = 5) + split_args <- .get_split_args(r_set) + r_split <- get_rsplit(r_set, 1) + + isplit <- inner_split(r_split, split_args) + + expect_identical( + isplit$data, + analysis(r_split) + ) + + expect_identical( + analysis(isplit), + isplit$data[isplit$in_id, ], + ignore_attr = "row.names" + ) + expect_identical( + assessment(isplit), + isplit$data[-isplit$in_id, ], + ignore_attr = "row.names" + ) +}) + +# apparent --------------------------------------------------------------- + +test_that("apparent_split", { + set.seed(11) + r_set <- apparent(warpbreaks) + r_split <- get_rsplit(r_set, 1) + + isplit <- inner_split(r_split) + + expect_identical( + isplit$data, + analysis(r_split) + ) + + expect_identical( + analysis(isplit), + analysis(r_split) + ) + expect_identical( + assessment(isplit), + analysis(r_split) + ) +})