Skip to content

Commit

Permalink
Implement only simple case for linear_pool_sample()
Browse files Browse the repository at this point in the history
  • Loading branch information
lshandross committed Nov 6, 2024
1 parent 4926a5f commit afaf5e0
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 110 deletions.
129 changes: 40 additions & 89 deletions R/linear_pool_sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,35 @@ linear_pool_sample <- function(model_out_tbl, weights = NULL,
model_id = "hub-ensemble",
task_id_cols = NULL,
n_output_samples = NULL) {
if (!is.null(n_output_samples) && !is.numeric(n_output_samples) && trunc(n_output_samples) != n_output_samples) {
cli::cli_abort("{.arg n_output_samples} must be {.val NULL} or coerceable to an integer")
}

if (!is.null(weights) && is.null(n_output_samples)) {
cli::cli_abort("Component model weights output samples provided,
so a number of ensemble samples {.arg n_output_samples} must be provided")
}
validate_sample_inputs(weights, weights_col_name, n_output_samples)

if (!is.null(weights) && !all(colnames(weights) %in% c("model_id", weights_col_name))) {
cli::cli_abort("Currently weights for different task IDs are not supported for the sample output type.")
num_models <- length(unique(model_out_tbl$model_id))
samples_per_model <- model_out_tbl |>
dplyr::group_by(dplyr::across("model_id")) |>
dplyr::summarize(provided_n_component_samples = dplyr::n()) |>
dplyr::ungroup() |>
dplyr::select("model_id", "provided_n_component_samples") |>
dplyr::distinct(.keep_all = TRUE)
unique_provided_samples <- unique(samples_per_model[["provided_n_component_samples"]])

if (is.null(weights)) {
weights <- data.frame(
model_id = unique(model_out_tbl$model_id),
weight = 1 / num_models,
stringsAsFactors = FALSE
)
weights_col_name <- "weight"
}
unique_weights <- unique(weights[[weights_col_name]])

if (!is.null(n_output_samples)) {
model_out_tbl <- model_out_tbl |>
subset_samples_stratified(weights = weights,
weights_col_name = weights_col_name,
task_id_cols = task_id_cols,
n_output_samples = n_output_samples)
if (length(unique_weights) != 1 || length(unique_provided_samples) != 1 || !is.null(n_output_samples)) {
cli::cli_abort(
"The requested ensemble calculation doesn't satisfy all conditions:
1) {.arg model_out_tbl} contains the same number of samples from each component model,
2) {.arg weights} are {.val NULL} or equal for every model,
3) {.arg n_output_samples} = {.val NULL}"
)
}

model_out_tbl |>
Expand Down Expand Up @@ -87,91 +97,32 @@ make_sample_indices_unique <- function(model_out_tbl) {
}


#' Helper function for subsetting model outputs of the sample type by taking a
#' stratified sample across models
#' Perform simple validations on the inputs used to calculate a linear pool
#' of samples
#'
#' @param model_out_tbl an object of class `model_out_tbl` with component
#' model outputs (e.g., predictions).
#' @param weights an optional `data.frame` with component model weights. If
#' provided, it should have a column named `model_id` and a column containing
#' model weights. The default is `NULL`, in which case an equally-weighted
#' ensemble is calculated. Should be prevalidated.
#' model weights. Default to `NULL`, which specifies an equally-weighted ensemble
#' @param weights_col_name `character` string naming the column in `weights`
#' with model weights. Defaults to `"weight"`
#' @param task_id_cols `character` vector with names of columns in
#' `model_out_tbl` that specify modeling tasks.
#' @param n_output_samples `numeric` that specifies how many sample forecasts to
#' return per unique combination of task IDs.
#'
#' @noRd
#' return per unique combination of task IDs. Defaults to NULL, in which case
#' all provided component model samples are collected and returned.
#'
#' @return a `model_out_tbl` object of ensemble predictions for the `sample`
#' output type. Note that the output type ID values will not match those of the
#' input model_out_tbl but do preserve relationships across unique task ID combos
#' @return no return value
#'
#' @importFrom rlang .data
subset_samples_stratified <- function(model_out_tbl, weights = NULL,
weights_col_name = "weight",
task_id_cols,
n_output_samples) {
num_models <- length(unique(model_out_tbl$model_id))
samples_per_model <- model_out_tbl |>
dplyr::group_by(dplyr::across(dplyr::all_of(c("model_id", task_id_cols)))) |>
dplyr::summarize(provided_n_component_samples = dplyr::n()) |>
dplyr::ungroup() |>
dplyr::select("model_id", "provided_n_component_samples") |>
dplyr::distinct(.keep_all = TRUE) # assumes same number per task ID combo

if (is.null(weights)) {
weights <- data.frame(
model_id = unique(model_out_tbl$model_id),
weight = 1 / num_models,
stringsAsFactors = FALSE
)
weights_col_name <- "weight"
#' @noRd
validate_sample_inputs <- function(weights = NULL, weights_col_name = "weight", n_output_samples = NULL) {
if (!is.null(n_output_samples) && !is.numeric(n_output_samples) && trunc(n_output_samples) != n_output_samples) {
cli::cli_abort("{.arg n_output_samples} must be {.val NULL} or an integer value")
}

samples_per_model <- samples_per_model |>
dplyr::left_join(weights, by = "model_id") |>
dplyr::mutate(target_n_component_samples = floor(.data[[weights_col_name]] * n_output_samples))
remainder_samples <- n_output_samples - sum(samples_per_model$target_n_component_samples)
remainder_model_indices <- sample(x = 1:num_models, size = remainder_samples)
samples_per_model$target_n_component_samples[remainder_model_indices] + 1

if (!length(unique(samples_per_model$provided_n_component_samples)) != 1 && is.null(n_output_samples)) {
cli::cli_abort("Component model provided differing numbers of samples within at least one forecast task id group,
if (!is.null(weights) && is.null(n_output_samples)) {
cli::cli_abort("Component model weights output samples provided,
so a number of ensemble samples {.arg n_output_samples} must be provided")
}

# iterate over component models and sample as requested
split_models <- model_out_tbl |>
dplyr::mutate(output_type_id = as.character(.data[["output_type_id"]])) |>
split(f = model_out_tbl$model_id)
model_out_tbl <- split_models |>
purrr::map(.f = function(split_outputs) {
current_model <- split_outputs$model_id[1]
provided_samples <- samples_per_model$provided_n_component_samples[samples_per_model$model_id == current_model]
target_samples <- samples_per_model$target_n_component_samples[samples_per_model$model_id == current_model]
provided_indices <- unique(split_outputs$output_type_id)
if (target_samples > provided_samples) {
# replicate the rows with those generated indices
duplications <- floor(target_samples / provided_samples) # always >= 1
duplicated_outputs <- dplyr::filter(split_outputs, .data[["output_type"]] != "sample")
for (i in 1:duplications) { # indices maintained across task id combos
duplicated_outputs <- split_outputs |>
dplyr::mutate(output_type_id = paste0(.data[["output_type_id"]], i)) |>
dplyr::bind_rows(duplicated_outputs)
}
# sample for the remaining values' indices
sample_index <- sample(x = provided_indices, size = target_samples %% provided_samples, replace = FALSE)
remainder_outputs <- split_outputs |>
dplyr::filter(.data[["output_type_id"]] %in% sample_index)
split_outputs <- dplyr::bind_rows(duplicated_outputs, remainder_outputs)
} else {
sample_index <- sample(x = provided_indices, size = target_samples, replace = FALSE)
split_outputs <- split_outputs |>
dplyr::filter(.data[["output_type_id"]] %in% sample_index)
}
}) |>
purrr::list_rbind()
if (!is.null(weights) && !all(colnames(weights) %in% c("model_id", weights_col_name))) {
cli::cli_abort("Currently weights for different task IDs are not supported for the sample output type.")
}
}
25 changes: 4 additions & 21 deletions tests/testthat/test-linear_pool.R
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ test_that("samples only collected and re-indexed for simplest case", {
})


test_that("ensemble of samples is correctly calculated for more complex cases", {
test_that("ensemble of samples throws an error for the more complex cases", {
sample_outputs <- expand.grid(stringsAsFactors = FALSE,
model_id = letters[1:4],
location = c("222", "888"),
Expand Down Expand Up @@ -492,29 +492,12 @@ test_that("ensemble of samples is correctly calculated for more complex cases",

fweight <- data.frame(model_id = letters[1:4], weight = 0.1 * (1:4))

reindexed_outputs <- sample_outputs |>
dplyr::mutate(output_type_id = paste0(.data[["model_id"]], .data[["output_type_id"]]))
expected_outputs <- reindexed_outputs |>
dplyr::filter(output_type_id %in% c("a2", "a3", "b1", "d1", "d3"))
for (i in 1:2) {
expected_outputs <- reindexed_outputs |>
dplyr::filter(.data[["model_id"]] %in% letters[(i + 1):4]) |>
dplyr::mutate(output_type_id = paste0(.data[["output_type_id"]], i)) |>
dplyr::bind_rows(expected_outputs)
}
expected_outputs <- expected_outputs |>
dplyr::mutate(model_id = "hub-ensemble") |>
dplyr::arrange(.data[["output_type_id"]]) |>
hubUtils::as_model_out_tbl()

set.seed(1234)
actual_outputs <- sample_outputs |>
sample_outputs |>
linear_pool(
weights = fweight,
task_id_cols = c("target_date", "target", "horizon", "location"),
n_output_samples = 20
) |>
dplyr::arrange(.data[["output_type_id"]])

expect_equal(actual_outputs, expected_outputs)
dplyr::arrange(.data[["output_type_id"]]) |>
expect_error("The requested ensemble calculation doesn't satisfy all conditions", fixed = TRUE)
})

0 comments on commit afaf5e0

Please sign in to comment.