From afaf5e02e514b366c90422e8d7ef2ae58f170259 Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Wed, 6 Nov 2024 17:05:06 -0500 Subject: [PATCH] Implement only simple case for `linear_pool_sample()` --- R/linear_pool_sample.R | 129 +++++++++--------------------- tests/testthat/test-linear_pool.R | 25 +----- 2 files changed, 44 insertions(+), 110 deletions(-) diff --git a/R/linear_pool_sample.R b/R/linear_pool_sample.R index 9835c54..3b64e6a 100644 --- a/R/linear_pool_sample.R +++ b/R/linear_pool_sample.R @@ -25,25 +25,35 @@ linear_pool_sample <- function(model_out_tbl, weights = NULL, model_id = "hub-ensemble", task_id_cols = NULL, n_output_samples = NULL) { - if (!is.null(n_output_samples) && !is.numeric(n_output_samples) && trunc(n_output_samples) != n_output_samples) { - cli::cli_abort("{.arg n_output_samples} must be {.val NULL} or coerceable to an integer") - } - if (!is.null(weights) && is.null(n_output_samples)) { - cli::cli_abort("Component model weights output samples provided, - so a number of ensemble samples {.arg n_output_samples} must be provided") - } + validate_sample_inputs(weights, weights_col_name, n_output_samples) - if (!is.null(weights) && !all(colnames(weights) %in% c("model_id", weights_col_name))) { - cli::cli_abort("Currently weights for different task IDs are not supported for the sample output type.") + num_models <- length(unique(model_out_tbl$model_id)) + samples_per_model <- model_out_tbl |> + dplyr::group_by(dplyr::across("model_id")) |> + dplyr::summarize(provided_n_component_samples = dplyr::n()) |> + dplyr::ungroup() |> + dplyr::select("model_id", "provided_n_component_samples") |> + dplyr::distinct(.keep_all = TRUE) + unique_provided_samples <- unique(samples_per_model[["provided_n_component_samples"]]) + + if (is.null(weights)) { + weights <- data.frame( + model_id = unique(model_out_tbl$model_id), + weight = 1 / num_models, + stringsAsFactors = FALSE + ) + weights_col_name <- "weight" } + unique_weights <- unique(weights[[weights_col_name]]) - if (!is.null(n_output_samples)) { - model_out_tbl <- model_out_tbl |> - subset_samples_stratified(weights = weights, - weights_col_name = weights_col_name, - task_id_cols = task_id_cols, - n_output_samples = n_output_samples) + if (length(unique_weights) != 1 || length(unique_provided_samples) != 1 || !is.null(n_output_samples)) { + cli::cli_abort( + "The requested ensemble calculation doesn't satisfy all conditions: + 1) {.arg model_out_tbl} contains the same number of samples from each component model, + 2) {.arg weights} are {.val NULL} or equal for every model, + 3) {.arg n_output_samples} = {.val NULL}" + ) } model_out_tbl |> @@ -87,91 +97,32 @@ make_sample_indices_unique <- function(model_out_tbl) { } -#' Helper function for subsetting model outputs of the sample type by taking a -#' stratified sample across models +#' Perform simple validations on the inputs used to calculate a linear pool +#' of samples #' -#' @param model_out_tbl an object of class `model_out_tbl` with component -#' model outputs (e.g., predictions). #' @param weights an optional `data.frame` with component model weights. If #' provided, it should have a column named `model_id` and a column containing -#' model weights. The default is `NULL`, in which case an equally-weighted -#' ensemble is calculated. Should be prevalidated. +#' model weights. Default to `NULL`, which specifies an equally-weighted ensemble #' @param weights_col_name `character` string naming the column in `weights` #' with model weights. Defaults to `"weight"` -#' @param task_id_cols `character` vector with names of columns in -#' `model_out_tbl` that specify modeling tasks. #' @param n_output_samples `numeric` that specifies how many sample forecasts to -#' return per unique combination of task IDs. -#' -#' @noRd +#' return per unique combination of task IDs. Defaults to NULL, in which case +#' all provided component model samples are collected and returned. #' -#' @return a `model_out_tbl` object of ensemble predictions for the `sample` -#' output type. Note that the output type ID values will not match those of the -#' input model_out_tbl but do preserve relationships across unique task ID combos +#' @return no return value #' -#' @importFrom rlang .data -subset_samples_stratified <- function(model_out_tbl, weights = NULL, - weights_col_name = "weight", - task_id_cols, - n_output_samples) { - num_models <- length(unique(model_out_tbl$model_id)) - samples_per_model <- model_out_tbl |> - dplyr::group_by(dplyr::across(dplyr::all_of(c("model_id", task_id_cols)))) |> - dplyr::summarize(provided_n_component_samples = dplyr::n()) |> - dplyr::ungroup() |> - dplyr::select("model_id", "provided_n_component_samples") |> - dplyr::distinct(.keep_all = TRUE) # assumes same number per task ID combo - - if (is.null(weights)) { - weights <- data.frame( - model_id = unique(model_out_tbl$model_id), - weight = 1 / num_models, - stringsAsFactors = FALSE - ) - weights_col_name <- "weight" +#' @noRd +validate_sample_inputs <- function(weights = NULL, weights_col_name = "weight", n_output_samples = NULL) { + if (!is.null(n_output_samples) && !is.numeric(n_output_samples) && trunc(n_output_samples) != n_output_samples) { + cli::cli_abort("{.arg n_output_samples} must be {.val NULL} or an integer value") } - samples_per_model <- samples_per_model |> - dplyr::left_join(weights, by = "model_id") |> - dplyr::mutate(target_n_component_samples = floor(.data[[weights_col_name]] * n_output_samples)) - remainder_samples <- n_output_samples - sum(samples_per_model$target_n_component_samples) - remainder_model_indices <- sample(x = 1:num_models, size = remainder_samples) - samples_per_model$target_n_component_samples[remainder_model_indices] + 1 - - if (!length(unique(samples_per_model$provided_n_component_samples)) != 1 && is.null(n_output_samples)) { - cli::cli_abort("Component model provided differing numbers of samples within at least one forecast task id group, + if (!is.null(weights) && is.null(n_output_samples)) { + cli::cli_abort("Component model weights output samples provided, so a number of ensemble samples {.arg n_output_samples} must be provided") } - # iterate over component models and sample as requested - split_models <- model_out_tbl |> - dplyr::mutate(output_type_id = as.character(.data[["output_type_id"]])) |> - split(f = model_out_tbl$model_id) - model_out_tbl <- split_models |> - purrr::map(.f = function(split_outputs) { - current_model <- split_outputs$model_id[1] - provided_samples <- samples_per_model$provided_n_component_samples[samples_per_model$model_id == current_model] - target_samples <- samples_per_model$target_n_component_samples[samples_per_model$model_id == current_model] - provided_indices <- unique(split_outputs$output_type_id) - if (target_samples > provided_samples) { - # replicate the rows with those generated indices - duplications <- floor(target_samples / provided_samples) # always >= 1 - duplicated_outputs <- dplyr::filter(split_outputs, .data[["output_type"]] != "sample") - for (i in 1:duplications) { # indices maintained across task id combos - duplicated_outputs <- split_outputs |> - dplyr::mutate(output_type_id = paste0(.data[["output_type_id"]], i)) |> - dplyr::bind_rows(duplicated_outputs) - } - # sample for the remaining values' indices - sample_index <- sample(x = provided_indices, size = target_samples %% provided_samples, replace = FALSE) - remainder_outputs <- split_outputs |> - dplyr::filter(.data[["output_type_id"]] %in% sample_index) - split_outputs <- dplyr::bind_rows(duplicated_outputs, remainder_outputs) - } else { - sample_index <- sample(x = provided_indices, size = target_samples, replace = FALSE) - split_outputs <- split_outputs |> - dplyr::filter(.data[["output_type_id"]] %in% sample_index) - } - }) |> - purrr::list_rbind() + if (!is.null(weights) && !all(colnames(weights) %in% c("model_id", weights_col_name))) { + cli::cli_abort("Currently weights for different task IDs are not supported for the sample output type.") + } } diff --git a/tests/testthat/test-linear_pool.R b/tests/testthat/test-linear_pool.R index 1a4dcdf..e8ea74b 100644 --- a/tests/testthat/test-linear_pool.R +++ b/tests/testthat/test-linear_pool.R @@ -460,7 +460,7 @@ test_that("samples only collected and re-indexed for simplest case", { }) -test_that("ensemble of samples is correctly calculated for more complex cases", { +test_that("ensemble of samples throws an error for the more complex cases", { sample_outputs <- expand.grid(stringsAsFactors = FALSE, model_id = letters[1:4], location = c("222", "888"), @@ -492,29 +492,12 @@ test_that("ensemble of samples is correctly calculated for more complex cases", fweight <- data.frame(model_id = letters[1:4], weight = 0.1 * (1:4)) - reindexed_outputs <- sample_outputs |> - dplyr::mutate(output_type_id = paste0(.data[["model_id"]], .data[["output_type_id"]])) - expected_outputs <- reindexed_outputs |> - dplyr::filter(output_type_id %in% c("a2", "a3", "b1", "d1", "d3")) - for (i in 1:2) { - expected_outputs <- reindexed_outputs |> - dplyr::filter(.data[["model_id"]] %in% letters[(i + 1):4]) |> - dplyr::mutate(output_type_id = paste0(.data[["output_type_id"]], i)) |> - dplyr::bind_rows(expected_outputs) - } - expected_outputs <- expected_outputs |> - dplyr::mutate(model_id = "hub-ensemble") |> - dplyr::arrange(.data[["output_type_id"]]) |> - hubUtils::as_model_out_tbl() - - set.seed(1234) - actual_outputs <- sample_outputs |> + sample_outputs |> linear_pool( weights = fweight, task_id_cols = c("target_date", "target", "horizon", "location"), n_output_samples = 20 ) |> - dplyr::arrange(.data[["output_type_id"]]) - - expect_equal(actual_outputs, expected_outputs) + dplyr::arrange(.data[["output_type_id"]]) |> + expect_error("The requested ensemble calculation doesn't satisfy all conditions", fixed = TRUE) })