Skip to content

Commit

Permalink
Refactor validate_ensemble_outputs() to reduce cyclomatic complexity
Browse files Browse the repository at this point in the history
  • Loading branch information
lshandross committed Dec 3, 2024
1 parent 11f1935 commit f0061e2
Showing 1 changed file with 52 additions and 32 deletions.
84 changes: 52 additions & 32 deletions R/validate_ensemble_inputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#' @param valid_output_types `character` vector with the names of valid output
#' types for the particular ensembling method used. See the details for more
#' information.
#'
#' @details If the ensembling function intended to be used is `"simple_ensemble"`,
#' the valid output types are `mean`, `median`, `quantile`, `cdf`, and `pmf`.
#' If the ensembling function will be `"linear_pool"`, the valid output types
Expand Down Expand Up @@ -90,38 +91,7 @@ validate_ensemble_inputs <- function(model_out_tbl, weights = NULL,
}

if (!is.null(weights)) {
req_weight_cols <- c("model_id", weights_col_name)
if (!all(req_weight_cols %in% colnames(weights))) {
cli::cli_abort(c(
"x" = "{.arg weights} did not include required columns
{.val {req_weight_cols}}."
))
}

weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name]

if ("value" %in% weight_by_cols) {
cli::cli_abort(c(
"x" = "{.arg weights} included a column named {.val {\"value\"}},
which is not allowed."
))
}

invalid_cols <- weight_by_cols[!weight_by_cols %in% colnames(model_out_tbl)]
if (length(invalid_cols) > 0) {
cli::cli_abort(c(
"x" = "{.arg weights} included {length(invalid_cols)} column{?s} that
{?was/were} not present in {.arg model_out_tbl}:
{.val {invalid_cols}}"
))
}

if (weights_col_name %in% colnames(model_out_tbl)) {
cli::cli_abort(c(
"x" = "The specified {.arg weights_col_name}, {.val {weights_col_name}},
is already a column in {.arg model_out_tbl}."
))
}
validate_weights(model_out_cols, weights, weights_col_name)

if (any(c("cdf", "pmf") %in% unique_output_types) && "output_type_id" %in% colnames(weights)) {
# nolint start
Expand All @@ -141,3 +111,53 @@ validate_ensemble_inputs <- function(model_out_tbl, weights = NULL,
comp_unit_cols = comp_unit_cols)
return(validated_inputs)
}


#' Perform basic validations on the model weights used to calculate an ensemble of
#' component model outputs for each combination of model task, output type,
#' and output type id.
#'
#' @param model_out_cols `character` string naming columns in a `model_out_tbl`
#' object of component predictions that will be ensembled using the model weights
#' in `weights`
#' @param weights a `data.frame` of component model weights to be validated. It must
#' contain a `model_id` column and a column giving weights, but may also contain
#' additional columns corresponding to task id variables, `output_type`, or
#' `output_type_id`, if weights are specific to values of those variables.
#' @param weights_col_name `character` string naming the column in `weights`
#' with model weights. Defaults to `"weight"`
#'
#' @return no return value
#' @noRd

validate_weights <- function(model_out_cols, weights = NULL, weights_col_name = "weight") {
req_weight_cols <- c("model_id", weights_col_name)
if (!all(req_weight_cols %in% colnames(weights))) {
cli::cli_abort(c(
"x" = "{.arg weights} did not include required columns {.val {req_weight_cols}}."
))
}

weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name]

if ("value" %in% weight_by_cols) {
cli::cli_abort(c(
"x" = "{.arg weights} included a column named {.val {\"value\"}}, which is not allowed."
))
}

invalid_cols <- weight_by_cols[!weight_by_cols %in% model_out_cols]
if (length(invalid_cols) > 0) {
cli::cli_abort(c(
"x" = "{.arg weights} included {length(invalid_cols)} column{?s} that
{?was/were} not present in {.arg model_out_tbl}: {.val {invalid_cols}}"
))
}

if (weights_col_name %in% model_out_cols) {
cli::cli_abort(c(
"x" = "The specified {.arg weights_col_name}, {.val {weights_col_name}},
is already a column in {.arg model_out_tbl}."
))
}
}

0 comments on commit f0061e2

Please sign in to comment.