From f0061e26a69f3d0817e43e832100e258eb8c4466 Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Tue, 3 Dec 2024 11:26:56 -0500 Subject: [PATCH] Refactor `validate_ensemble_outputs()` to reduce cyclomatic complexity --- R/validate_ensemble_inputs.R | 84 ++++++++++++++++++++++-------------- 1 file changed, 52 insertions(+), 32 deletions(-) diff --git a/R/validate_ensemble_inputs.R b/R/validate_ensemble_inputs.R index daa8c0d..a2823d0 100644 --- a/R/validate_ensemble_inputs.R +++ b/R/validate_ensemble_inputs.R @@ -25,6 +25,7 @@ #' @param valid_output_types `character` vector with the names of valid output #' types for the particular ensembling method used. See the details for more #' information. +#' #' @details If the ensembling function intended to be used is `"simple_ensemble"`, #' the valid output types are `mean`, `median`, `quantile`, `cdf`, and `pmf`. #' If the ensembling function will be `"linear_pool"`, the valid output types @@ -90,38 +91,7 @@ validate_ensemble_inputs <- function(model_out_tbl, weights = NULL, } if (!is.null(weights)) { - req_weight_cols <- c("model_id", weights_col_name) - if (!all(req_weight_cols %in% colnames(weights))) { - cli::cli_abort(c( - "x" = "{.arg weights} did not include required columns - {.val {req_weight_cols}}." - )) - } - - weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name] - - if ("value" %in% weight_by_cols) { - cli::cli_abort(c( - "x" = "{.arg weights} included a column named {.val {\"value\"}}, - which is not allowed." - )) - } - - invalid_cols <- weight_by_cols[!weight_by_cols %in% colnames(model_out_tbl)] - if (length(invalid_cols) > 0) { - cli::cli_abort(c( - "x" = "{.arg weights} included {length(invalid_cols)} column{?s} that - {?was/were} not present in {.arg model_out_tbl}: - {.val {invalid_cols}}" - )) - } - - if (weights_col_name %in% colnames(model_out_tbl)) { - cli::cli_abort(c( - "x" = "The specified {.arg weights_col_name}, {.val {weights_col_name}}, - is already a column in {.arg model_out_tbl}." - )) - } + validate_weights(model_out_cols, weights, weights_col_name) if (any(c("cdf", "pmf") %in% unique_output_types) && "output_type_id" %in% colnames(weights)) { # nolint start @@ -141,3 +111,53 @@ validate_ensemble_inputs <- function(model_out_tbl, weights = NULL, comp_unit_cols = comp_unit_cols) return(validated_inputs) } + + +#' Perform basic validations on the model weights used to calculate an ensemble of +#' component model outputs for each combination of model task, output type, +#' and output type id. +#' +#' @param model_out_cols `character` string naming columns in a `model_out_tbl` +#' object of component predictions that will be ensembled using the model weights +#' in `weights` +#' @param weights a `data.frame` of component model weights to be validated. It must +#' contain a `model_id` column and a column giving weights, but may also contain +#' additional columns corresponding to task id variables, `output_type`, or +#' `output_type_id`, if weights are specific to values of those variables. +#' @param weights_col_name `character` string naming the column in `weights` +#' with model weights. Defaults to `"weight"` +#' +#' @return no return value +#' @noRd + +validate_weights <- function(model_out_cols, weights = NULL, weights_col_name = "weight") { + req_weight_cols <- c("model_id", weights_col_name) + if (!all(req_weight_cols %in% colnames(weights))) { + cli::cli_abort(c( + "x" = "{.arg weights} did not include required columns {.val {req_weight_cols}}." + )) + } + + weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name] + + if ("value" %in% weight_by_cols) { + cli::cli_abort(c( + "x" = "{.arg weights} included a column named {.val {\"value\"}}, which is not allowed." + )) + } + + invalid_cols <- weight_by_cols[!weight_by_cols %in% model_out_cols] + if (length(invalid_cols) > 0) { + cli::cli_abort(c( + "x" = "{.arg weights} included {length(invalid_cols)} column{?s} that + {?was/were} not present in {.arg model_out_tbl}: {.val {invalid_cols}}" + )) + } + + if (weights_col_name %in% model_out_cols) { + cli::cli_abort(c( + "x" = "The specified {.arg weights_col_name}, {.val {weights_col_name}}, + is already a column in {.arg model_out_tbl}." + )) + } +}