Skip to content

Commit

Permalink
simplify concept of outcome type in the package (#14)
Browse files Browse the repository at this point in the history
* remove `container(mode)`
* rename `adjust_*_calibration(type)` to `adjust_*_calibration(method)`
  • Loading branch information
simonpcouch authored May 2, 2024
1 parent 071749e commit 1412531
Show file tree
Hide file tree
Showing 35 changed files with 248 additions and 177 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: container
Title: Sandbox for a postprocessor object
Version: 0.0.0.9000
Version: 0.0.0.9001
Authors@R: c(
person("Simon", "Couch", , "[email protected]", role = "aut"),
person("Hannah", "Frick", , "[email protected]", role = "aut"),
Expand Down
3 changes: 1 addition & 2 deletions R/adjust-equivocal-zone.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#' library(modeldata)
#'
#' post_obj <-
#' container(mode = "classification") %>%
#' container() %>%
#' adjust_equivocal_zone(value = 1 / 4)
#'
#'
Expand Down Expand Up @@ -43,7 +43,6 @@ adjust_equivocal_zone <- function(x, value = 0.1, threshold = 1 / 2) {
)

new_container(
mode = x$mode,
type = x$type,
operations = c(x$operations, list(op)),
columns = x$dat,
Expand Down
21 changes: 10 additions & 11 deletions R/adjust-numeric-calibration.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Re-calibrate numeric predictions
#'
#' @param x A [container()].
#' @param type Character. One of `"linear"`, `"isotonic"`, or
#' @param method Character. One of `"linear"`, `"isotonic"`, or
#' `"isotonic_boot"`, corresponding to the function from the \pkg{probably}
#' package [probably::cal_estimate_linear()],
#' [probably::cal_estimate_isotonic()], or
Expand All @@ -19,21 +19,21 @@
#'
#' # specify calibration
#' reg_ctr <-
#' container(mode = "regression") %>%
#' adjust_numeric_calibration(type = "linear")
#' container() %>%
#' adjust_numeric_calibration(method = "linear")
#'
#' # train container
#' reg_ctr_trained <- fit(reg_ctr, dat, outcome = y, estimate = y_pred)
#'
#' predict(reg_ctr_trained, dat)
#' @export
adjust_numeric_calibration <- function(x, type = NULL) {
adjust_numeric_calibration <- function(x, method = NULL) {
# to-do: add argument specifying `prop` in initial_split
check_container(x, calibration_type = "numeric")
# wait to `check_type()` until `fit()` time
if (!is.null(type)) {
# wait to `check_method()` until `fit()` time
if (!is.null(method)) {
arg_match0(
type,
method,
c("linear", "isotonic", "isotonic_boot")
)
}
Expand All @@ -43,13 +43,12 @@ adjust_numeric_calibration <- function(x, type = NULL) {
"numeric_calibration",
inputs = "numeric",
outputs = "numeric",
arguments = list(type = type),
arguments = list(method = method),
results = list(),
trained = FALSE
)

new_container(
mode = x$mode,
type = x$type,
operations = c(x$operations, list(op)),
columns = x$dat,
Expand All @@ -67,13 +66,13 @@ print.numeric_calibration <- function(x, ...) {

#' @export
fit.numeric_calibration <- function(object, data, container = NULL, ...) {
type <- check_type(object$type, container$type)
method <- check_method(object$method, container$type)
# todo: adjust_numeric_calibration() should take arguments to pass to
# cal_estimate_* via dots
fit <-
eval_bare(
call2(
paste0("cal_estimate_", type),
paste0("cal_estimate_", method),
.data = data,
truth = container$columns$outcome,
estimate = container$columns$estimate,
Expand Down
1 change: 0 additions & 1 deletion R/adjust-numeric-range.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ adjust_numeric_range <- function(x, lower_limit = -Inf, upper_limit = Inf) {
)

new_container(
mode = x$mode,
type = x$type,
operations = c(x$operations, list(op)),
columns = x$dat,
Expand Down
3 changes: 1 addition & 2 deletions R/adjust-predictions-custom.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#' library(modeldata)
#'
#' post_obj <-
#' container(mode = "classification") %>%
#' container() %>%
#' adjust_equivocal_zone() %>%
#' adjust_predictions_custom(linear_predictor = binomial()$linkfun(Class2))
#'
Expand Down Expand Up @@ -39,7 +39,6 @@ adjust_predictions_custom <- function(x, ..., .pkgs = character(0)) {
)

new_container(
mode = x$mode,
type = x$type,
operations = c(x$operations, list(op)),
columns = x$dat,
Expand Down
17 changes: 8 additions & 9 deletions R/adjust-probability-calibration.R
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#' Re-calibrate classification probability predictions
#'
#' @param x A [container()].
#' @param type Character. One of `"logistic"`, `"multinomial"`,
#' @param method Character. One of `"logistic"`, `"multinomial"`,
#' `"beta"`, `"isotonic"`, or `"isotonic_boot"`, corresponding to the
#' function from the \pkg{probably} package [probably::cal_estimate_logistic()],
#' [probably::cal_estimate_multinomial()], etc., respectively.
#' @export
adjust_probability_calibration <- function(x, type = NULL) {
adjust_probability_calibration <- function(x, method = NULL) {
# to-do: add argument specifying `prop` in initial_split
check_container(x, calibration_type = "probability")
# wait to `check_type()` until `fit()` time
if (!is.null(type)) {
# wait to `check_method()` until `fit()` time
if (!is.null(method)) {
arg_match(
type,
method,
c("logistic", "multinomial", "beta", "isotonic", "isotonic_boot")
)
}
Expand All @@ -22,13 +22,12 @@ adjust_probability_calibration <- function(x, type = NULL) {
"probability_calibration",
inputs = "probability",
outputs = "probability_class",
arguments = list(type = type),
arguments = list(method = method),
results = list(),
trained = FALSE
)

new_container(
mode = x$mode,
type = x$type,
operations = c(x$operations, list(op)),
columns = x$dat,
Expand All @@ -46,14 +45,14 @@ print.probability_calibration <- function(x, ...) {

#' @export
fit.probability_calibration <- function(object, data, container = NULL, ...) {
type <- check_type(object$type, container$type)
method <- check_method(object$method, container$type)
# todo: adjust_probability_calibration() should take arguments to pass to
# cal_estimate_* via dots
# to-do: add argument specifying `prop` in initial_split
fit <-
eval_bare(
call2(
paste0("cal_estimate_", type),
paste0("cal_estimate_", method),
.data = data,
# todo: make getters for the entries in `columns`
truth = container$columns$outcome,
Expand Down
3 changes: 1 addition & 2 deletions R/adjust-probability-threshold.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#' library(modeldata)
#'
#' post_obj <-
#' container(mode = "classification") %>%
#' container() %>%
#' adjust_probability_threshold(threshold = .1)
#'
#' two_class_example %>% count(predicted)
Expand Down Expand Up @@ -39,7 +39,6 @@ adjust_probability_threshold <- function(x, threshold = 0.5) {
)

new_container(
mode = x$mode,
type = x$type,
operations = c(x$operations, list(op)),
columns = x$dat,
Expand Down
20 changes: 5 additions & 15 deletions R/container.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#' Declare post-processing for model predictions
#'
#' @param mode The model's mode, one of `"classification"`, or `"regression"`.
#' Modes of `"censored regression"` are not currently supported.
#' @param type The model sub-type. Possible values are `"unknown"`, `"regression"`,
#' `"binary"`, or `"multiclass"`.
#' @param outcome The name of the outcome variable.
Expand All @@ -14,9 +12,9 @@
#' @param time The name of the predicted event time. (not yet supported)
#' @examples
#'
#' container(mode = "regression")
#' container()
#' @export
container <- function(mode, type = "unknown", outcome = NULL, estimate = NULL,
container <- function(type = "unknown", outcome = NULL, estimate = NULL,
probabilities = NULL, time = NULL) {
columns <-
list(
Expand All @@ -28,7 +26,6 @@ container <- function(mode, type = "unknown", outcome = NULL, estimate = NULL,
)

new_container(
mode,
type,
operations = list(),
columns = columns,
Expand All @@ -37,13 +34,7 @@ container <- function(mode, type = "unknown", outcome = NULL, estimate = NULL,
)
}

new_container <- function(mode, type, operations, columns, ptype, call) {
mode <- arg_match0(mode, c("regression", "classification"))

if (mode == "regression") {
type <- "regression"
}

new_container <- function(type, operations, columns, ptype, call) {
type <- arg_match0(type, c("unknown", "regression", "binary", "multiclass"))

if (!is.list(operations)) {
Expand All @@ -58,11 +49,11 @@ new_container <- function(mode, type, operations, columns, ptype, call) {
}

# validate operation order and check duplicates
validate_order(operations, mode, call)
validate_order(operations, type, call)

# check columns
res <- list(
mode = mode, type = type, operations = operations,
type = type, operations = operations,
columns = columns, ptype = ptype
)
class(res) <- "container"
Expand Down Expand Up @@ -120,7 +111,6 @@ fit.container <- function(object, .data, outcome, estimate, probabilities = c(),
object <- set_container_type(object, .data[[columns$outcome]])

object <- new_container(
object$mode,
object$type,
operations = object$operations,
columns = columns,
Expand Down
37 changes: 18 additions & 19 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ check_container <- function(x, calibration_type = NULL, call = caller_env(), arg
# check that the type of calibration ("numeric" or "probability") is
# compatible with the container type
if (!is.null(calibration_type)) {
container_type <- x$type
type <- x$type
switch(
container_type,
type,
regression =
check_calibration_type(calibration_type, "numeric", container_type, call = call),
binary = , multinomial =
check_calibration_type(calibration_type, "probability", container_type, call = call)
check_calibration_type(calibration_type, "numeric", type, call = call),
binary = , multiclass =
check_calibration_type(calibration_type, "probability", type, call = call)
)
}

Expand All @@ -90,54 +90,53 @@ types_binary <- c("logistic", "beta", "isotonic", "isotonic_boot")
types_multiclass <- c("multinomial", "beta", "isotonic", "isotonic_boot")
# a check function to be called when a container is being `fit()`ted.
# by the time a container is fitted, we have:
# * `adjust_type`, the `type` argument passed to an `adjust_*` function
# * `method`, the `method` argument passed to an `adjust_*` function
# * this argument has already been checked to agree with the kind of
# `adjust_*()` function via `arg_match0()`.
# * `type`, the container's `type`, either specified in `container()`
# or inferred in `fit.container()`.
check_type <- function(adjust_type,
container_type,
arg = caller_arg(adjust_type),
check_method <- function(method,
type,
arg = caller_arg(method),
call = caller_env()) {
# if no `adjust_type` was supplied, infer a reasonable one based on the
# `container_type`
if (is.null(adjust_type)) {
# if no `method` was supplied, infer a reasonable one based on the `type`
if (is.null(method)) {
switch(
container_type,
type,
regression = return("linear"),
binary = return("logistic"),
multiclass = return("multinomial")
)
}

switch(
container_type,
type,
regression = arg_match0(
adjust_type,
method,
types_regression,
arg_nm = arg,
error_call = call
),
binary = arg_match0(
adjust_type,
method,
types_binary,
arg_nm = arg,
error_call = call
),
multiclass = arg_match0(
adjust_type,
method,
types_multiclass,
arg_nm = arg,
error_call = call
),
arg_match0(
adjust_type,
method,
unique(c(types_regression, types_binary, types_multiclass)),
arg_nm = arg,
error_call = call
)
)

adjust_type
method
}

31 changes: 26 additions & 5 deletions R/validation-rules.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
validate_order <- function(ops, mode, call) {
validate_order <- function(ops, type, call = caller_env()) {
orderings <-
tibble::new_tibble(list(
name = purrr::map_chr(ops, ~ class(.x)[1]),
Expand All @@ -13,12 +13,17 @@ validate_order <- function(ops, mode, call) {
return(invisible(orderings))
}

if (mode == "classification") {
check_classification_order(orderings, call)
} else {
check_regression_order(orderings, call)
if (type == "unknown") {
type <- infer_type(orderings)
}

switch(
type,
regression = check_regression_order(orderings, call),
binary = , multiclass = check_classification_order(orderings, call),
invisible()
)

invisible(orderings)
}

Expand Down Expand Up @@ -83,3 +88,19 @@ check_duplicates <- function(x, call) {
}
invisible(x)
}

# Guess a container type from the outputs its operations produce.
#
# `orderings` is a data frame with one row per operation and the logical
# columns `output_all`, `output_numeric`, `output_prob`, and `output_class`.
#
# Returns a single string:
# * "unknown" when every operation is flagged `output_all`, or when the
#   operations fit neither pattern below;
# * "regression" when every operation is flagged `output_numeric` or
#   `output_all`;
# * "binary" when every operation is flagged `output_prob`, `output_class`,
#   or `output_all`.
infer_type <- function(orderings) {
  open_ended <- orderings$output_all

  # The checks run in order, so the numeric pattern wins over the
  # probability/class pattern when both would match.
  if (all(open_ended)) {
    inferred <- "unknown"
  } else if (all(open_ended | orderings$output_numeric)) {
    inferred <- "regression"
  } else if (
    all(open_ended | orderings$output_prob | orderings$output_class)
  ) {
    inferred <- "binary"
  } else {
    inferred <- "unknown"
  }

  inferred
}
2 changes: 1 addition & 1 deletion inst/examples/container_regression_example.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ We could manually use `cal_apply()` to adjust predictions, but instead, we'll ad
#| label: post-1
post_obj <-
container(mode = "regression") %>%
container() %>%
adjust_numeric_calibration(bst_cal)
post_obj
```
Expand Down
Loading

0 comments on commit 1412531

Please sign in to comment.