Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Potato/inner/calibration split #483

Merged
merged 11 commits into from
May 23, 2024
8 changes: 8 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ S3method(dim,initial_validation_split)
S3method(dim,rsplit)
S3method(get_rsplit,default)
S3method(get_rsplit,rset)
S3method(inner_split,apparent_split)
S3method(inner_split,clustering_split)
S3method(inner_split,group_mc_split)
S3method(inner_split,group_vfold_split)
S3method(inner_split,mc_split)
S3method(inner_split,vfold_split)
S3method(int_bca,bootstraps)
S3method(int_pctl,bootstraps)
S3method(int_t,bootstraps)
Expand Down Expand Up @@ -343,6 +349,7 @@ S3method(vec_restore,validation_split)
S3method(vec_restore,validation_time_split)
S3method(vec_restore,vfold_cv)
export(.get_fingerprint)
export(.get_split_args)
export(add_resample_id)
export(all_of)
export(analysis)
Expand All @@ -368,6 +375,7 @@ export(initial_split)
export(initial_time_split)
export(initial_validation_split)
export(initial_validation_time_split)
export(inner_split)
export(int_bca)
export(int_pctl)
export(int_t)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# rsample (development version)

* The new `inner_split()` function and its methods for various resamples is for usage in tune to create a inner resample of the analysis set to fit the preprocessor and model on one part and the post-processor on the other part (#483).

## Bug fixes

* `vfold_cv()` now utilizes the `breaks` argument correctly for repeated cross-validation (@ZWael, #471).
Expand Down
159 changes: 159 additions & 0 deletions R/inner_split.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#' Inner split of the analysis set for fitting a post-processor
#'
#' @param x An `rsplit` object.
#' @param split_args A list of arguments to be used for the inner split.
#' @param ... Not currently used.
#' @return An `rsplit` object.
hfrick marked this conversation as resolved.
Show resolved Hide resolved
#' @details
#' `rsplit` objects live most commonly inside of an `rset` object. The
#' `split_args` argument can be the output of [.get_split_args()] on that
#' corresponding `rset` object, even if some of the arguments used to creat the
#' `rset` object are not needed for the inner split.
#' * For `mc_split` and `group_mc_split` objects, `inner_split()` will ignore
#' `split_args$times`.
#' * For `vfold_split` and `group_vfold_split` objects, it will ignore
#' `split_args$times` and `split_args$repeats`. `split_args$v` will be used to
#' set `split_args$prop` to `1 - 1/v` if `prop` is not already set and otherwise
#' ignored. The method
#' for `group_vfold_split` will always use `split_args$balance = NULL`.
#' * For `clustering_split` objects, it will ignore `split_args$repeats`.
#'
#' @keywords internal
#' @export
inner_split <- function(x, ...) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we happy with that name? We can also go back and change it later (like container -> tailor).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do kind of appreciate that inner_split() gives the vibe that these methods are for internal use. inner_split() feels good to me! I'm open to other options but would give preference to names that hint these are for expert use only.

UseMethod("inner_split")
}

# mc ---------------------------------------------------------------------

#' @rdname inner_split
#' @export
inner_split.mc_split <- function(x, split_args, ...) {
check_dots_empty()

analysis_set <- analysis(x)

split_args$times <- 1
split_inner <- rlang::inject(
mc_splits(analysis_set, !!!split_args)
)
split_inner <- split_inner$splits[[1]]

class_inner <- paste0(class(x)[1], "_inner")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not to be pedantic, but wouldn't class_inner by definition be "mc_split_inner"? same for other class_inner. While i appreciate the same code being used across, I think we could just note the class directly

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean that in reference to this particular method or for all the methods?

Since I had just come across #478, I opted for constructing the class rather than writing it out manually here to make sure it would always stay in sync with the class of the input x. In terms of readability, I would say that the class of that one is fairly easy to see from the S3 dispatch.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what i meant, is that each paste0(class(x)[1], "_inner") in this file could be swapped with a mc_split_inner, apparent_split_inner, etc etc as they are called inside s3 methods, on the object that drives the s3 dispatch

split_inner <- add_class(split_inner, class_inner)
split_inner
}

#' @rdname inner_split
#' @export
inner_split.group_mc_split <- function(x, split_args, ...) {
check_dots_empty()

analysis_set <- analysis(x)

split_args$times <- 1
split_inner <- rlang::inject(
group_mc_splits(analysis_set, !!!split_args)
)
split_inner <- split_inner$splits[[1]]

class_inner <- paste0(class(x)[1], "_inner")
split_inner <- add_class(split_inner, class_inner)
split_inner
}


# vfold ------------------------------------------------------------------

#' @rdname inner_split
#' @export
inner_split.vfold_split <- function(x, split_args, ...) {
check_dots_empty()

analysis_set <- analysis(x)

# TODO should this be done outside of rsample,
# in workflows or tune?
hfrick marked this conversation as resolved.
Show resolved Hide resolved
if (is.null(split_args$prop)) {
split_args$prop <- 1 - 1/split_args$v
}
# use mc_splits for a random split
split_args$times <- 1
split_args$v <- NULL
split_args$repeats <- NULL
split_inner <- rlang::inject(
mc_splits(analysis_set, !!!split_args)
)
split_inner <- split_inner$splits[[1]]

class_inner <- paste0(class(x)[1], "_inner")
class(split_inner) <- c(class_inner, class(x))
split_inner
}

#' @rdname inner_split
#' @export
inner_split.group_vfold_split <- function(x, split_args, ...) {
check_dots_empty()

analysis_set <- analysis(x)

# TODO should this be done outside of rsample,
# in workflows or tune?
if (is.null(split_args$prop)) {
split_args$prop <- 1 - 1/split_args$v
}

# use group_mc_splits for a random split
split_args$times <- 1
split_args$v <- NULL
split_args$repeats <- NULL
split_args$balance <- NULL
split_inner <- rlang::inject(
group_mc_splits(analysis_set, !!!split_args)
)
split_inner <- split_inner$splits[[1]]

class_inner <- paste0(class(x)[1], "_inner")
class(split_inner) <- c(class_inner, class(x))
split_inner
}

# clustering -------------------------------------------------------------

#' @rdname inner_split
#' @export
inner_split.clustering_split <- function(x, split_args, ...) {
check_dots_empty()

analysis_set <- analysis(x)

# TODO: reduce the number of clusters by 1 in tune?
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that the basic idea of clustering_cv() is to use one cluster as the assessment set, I would reduce v by one for the inner split, so that the cluster left out for the inner split is more likely to be similar to one of the original clusters. If we use the same v, the inner clustering is likely to break up the v-1 clusters in this (outer) analysis set.

I would put that into tune though, not here.

split_args$repeats <- 1
split_inner <- rlang::inject(
clustering_cv(analysis_set, !!!split_args)
)
split_inner <- split_inner$splits[[1]]

class_inner <- paste0(class(x)[1], "_inner")
class(split_inner) <- c(class_inner, class(x))
split_inner
}


# apparent ---------------------------------------------------------------

#' @rdname inner_split
#' @export
inner_split.apparent_split <- function(x, ...) {
check_dots_empty()

analysis_set <- analysis(x)

split_inner <- apparent(analysis_set)
split_inner <- split_inner$splits[[1]]

class_inner <- paste0(class(x)[1], "_inner")
class(split_inner) <- c(class_inner, class(x))
split_inner
}
30 changes: 21 additions & 9 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -271,22 +271,18 @@ reshuffle_rset <- function(rset) {
}
}

arguments <- attributes(rset)
useful_arguments <- names(formals(arguments$class[[1]]))
useful_arguments <- arguments[useful_arguments]
useful_arguments <- useful_arguments[!is.na(names(useful_arguments))]
if (identical(useful_arguments$strata, FALSE)) {
useful_arguments$strata <- NULL
} else if (identical(useful_arguments$strata, TRUE)) {
rset_type <- class(rset)[[1]]
split_arguments <- .get_split_args(rset)
if (identical(split_arguments$strata, TRUE)) {
rlang::abort(
"Cannot reshuffle this rset (`attr(rset, 'strata')` is `TRUE`, not a column identifier)",
i = "If the original object was created with an older version of rsample, try recreating it with the newest version of the package"
)
}

do.call(
arguments$class[[1]],
c(list(data = rset$splits[[1]]$data), useful_arguments)
rset_type,
c(list(data = rset$splits[[1]]$data), split_arguments)
)
}

Expand All @@ -299,6 +295,22 @@ non_random_classes <- c(
"validation_set"
)

#' Get the split arguments from an rset
#' @param rset An `rset` object.
#' @return A list of arguments used to create the rset.
#' @keywords internal
#' @export
.get_split_args <- function(rset) {
all_attributes <- attributes(rset)
args <- names(formals(all_attributes$class[[1]]))
split_args <- all_attributes[args]
split_args <- split_args[!is.na(names(split_args))]
if (identical(split_args$strata, FALSE)) {
split_args$strata <- NULL
}
split_args
}

#' Retrieve individual rsplits objects from an rset
#'
#' @param x The `rset` object to retrieve an rsplit from.
Expand Down
18 changes: 18 additions & 0 deletions man/dot-get_split_args.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 56 additions & 0 deletions man/inner_split.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading