Skip to content

Commit

Permalink
fit calibrators at fit.container()
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Apr 26, 2024
1 parent 3874410 commit 2821988
Show file tree
Hide file tree
Showing 14 changed files with 173 additions and 109 deletions.
47 changes: 27 additions & 20 deletions R/adjust-numeric-calibration.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#' Re-calibrate numeric predictions
#'
#' @param x A [container()].
#' @param calibrator A pre-trained calibration method from the \pkg{probably}
#' package, such as [probably::cal_estimate_linear()].
#' @param type Character. One of `"linear"`, `"isotonic"`, or
#' `"isotonic_boot"`, corresponding to the functions
#' [probably::cal_estimate_linear()], [probably::cal_estimate_isotonic()], and
#' [probably::cal_estimate_isotonic_boot()] from the \pkg{probably} package,
#' respectively. If `NULL` (the default), `"linear"` is used for a regression
#' container.
#' @examples
#' library(modeldata)
#' library(probably)
Expand All @@ -14,36 +17,27 @@
#'
#' dat
#'
#' # calibrate numeric predictions
#' reg_cal <- cal_estimate_linear(dat, truth = y, estimate = y_pred)
#'
#' # specify calibration
#' reg_ctr <-
#' container(mode = "regression") %>%
#' adjust_numeric_calibration(reg_cal)
#' adjust_numeric_calibration(type = "linear")
#'
#' # "train" container
#' # train container
#' reg_ctr_trained <- fit(reg_ctr, dat, outcome = y, estimate = y_pred)
#'
#' predict(reg_ctr, dat)
#' predict(reg_ctr_trained, dat)
#' @export
adjust_numeric_calibration <- function(x, calibrator) {
adjust_numeric_calibration <- function(x, type = NULL) {
# to-do: add argument specifying `prop` in initial_split
check_container(x)
check_required(calibrator)
if (!inherits(calibrator, "cal_regression")) {
cli_abort(
"{.arg calibrator} should be a \\
{.help [<cal_regression> object](probably::cal_estimate_linear)}, \\
not {.obj_type_friendly {calibrator}}."
)
}
type <- check_type(type, x$type)

op <-
new_operation(
"numeric_calibration",
inputs = "numeric",
outputs = "numeric",
arguments = list(calibrator = calibrator),
arguments = list(type = type),
results = list(),
trained = FALSE
)
Expand All @@ -67,19 +61,32 @@ print.numeric_calibration <- function(x, ...) {

#' @export
fit.numeric_calibration <- function(object, data, container = NULL, ...) {
  # Estimate the calibration model requested when the adjustment was specified
  # and return a trained copy of the operation that stores it in `results$fit`.
  #
  # `object` is the untrained operation, `data` is the data the calibrator is
  # estimated on, and `container` supplies the outcome and estimate column
  # specifications through `container$columns`.
  #
  # todo: adjust_numeric_calibration() should take arguments to pass to
  # cal_estimate_* via dots

  # Build and evaluate `probably::cal_estimate_<type>(...)`; the estimating
  # function is selected at run time from the stored `type` argument.
  fit <-
    eval_bare(
      call2(
        paste0("cal_estimate_", object$arguments$type),
        .data = data,
        truth = container$columns$outcome,
        estimate = container$columns$estimate,
        .ns = "probably"
      )
    )

  # `results` must be passed exactly once; supplying both an empty list and
  # the fitted calibrator would fail argument matching in `new_operation()`.
  new_operation(
    class(object),
    inputs = object$inputs,
    outputs = object$outputs,
    arguments = object$arguments,
    results = list(fit = fit),
    trained = TRUE
  )
}

#' @export
predict.numeric_calibration <- function(object, new_data, container, ...) {
  # Apply the calibrator estimated by `fit.numeric_calibration()` to
  # `new_data`. The fitted model lives in `object$results$fit`; the operation's
  # arguments only hold the calibration `type`, so there is no pre-trained
  # `calibrator` argument to read from any more.
  probably::cal_apply(new_data, object$results$fit)
}

# todo probably needs required_pkgs methods for cal objects
Expand Down
40 changes: 25 additions & 15 deletions R/adjust-probability-calibration.R
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
#' Re-calibrate classification probability predictions
#'
#' @param x A [container()].
#' @param calibrator A pre-trained calibration method from the \pkg{probably}
#' package, such as [probably::cal_estimate_logistic()].
#' @param type Character. One of `"logistic"`, `"multinomial"`,
#' `"beta"`, `"isotonic"`, or `"isotonic_boot"`, corresponding to the
#' functions [probably::cal_estimate_logistic()],
#' [probably::cal_estimate_multinomial()], etc., from the \pkg{probably}
#' package, respectively. If `NULL` (the default), `"logistic"` is used for a
#' binary container and `"multinomial"` for a multiclass container.
#' @export
adjust_probability_calibration <- function(x, calibrator) {
adjust_probability_calibration <- function(x, type = NULL) {
# to-do: add argument specifying `prop` in initial_split
check_container(x)
cls <- c("cal_binary", "cal_multinomial")
check_required(calibrator)
if (!inherits_any(calibrator, cls)) {
cli_abort(
"{.arg calibrator} should be a \\
{.help [<cal_binary> or <cal_multinomial> object](probably::cal_estimate_logistic)}, \\
not {.obj_type_friendly {calibrator}}."
)
}
type <- check_type(type, x$type)

op <-
new_operation(
"probability_calibration",
inputs = "probability",
outputs = "probability_class",
arguments = list(calibrator = calibrator),
arguments = list(type = type),
results = list(),
trained = FALSE
)
Expand All @@ -45,19 +40,34 @@ print.probability_calibration <- function(x, ...) {

#' @export
fit.probability_calibration <- function(object, data, container = NULL, ...) {
  # Estimate the calibration model requested when the adjustment was specified
  # and return a trained copy of the operation that stores it in `results$fit`.
  #
  # `object` is the untrained operation, `data` is the data the calibrator is
  # estimated on, and `container` supplies the outcome and estimate column
  # specifications through `container$columns`.
  #
  # todo: adjust_probability_calibration() should take arguments to pass to
  # cal_estimate_* via dots
  # to-do: add argument specifying `prop` in initial_split

  # The calibration type is stored under `arguments` (see the `new_operation()`
  # call in `adjust_probability_calibration()`), not at the top level of the
  # operation; reading `object$type` would yield NULL and build a call to the
  # non-existent `probably::cal_estimate_`.
  fit <-
    eval_bare(
      call2(
        paste0("cal_estimate_", object$arguments$type),
        .data = data,
        # todo: make getters for the entries in `columns`
        truth = container$columns$outcome,
        estimate = container$columns$estimate,
        .ns = "probably"
      )
    )

  # `results` must be passed exactly once; supplying both an empty list and
  # the fitted calibrator would fail argument matching in `new_operation()`.
  new_operation(
    class(object),
    inputs = object$inputs,
    outputs = object$outputs,
    arguments = object$arguments,
    results = list(fit = fit),
    trained = TRUE
  )
}

#' @export
predict.probability_calibration <- function(object, new_data, container, ...) {
  # Apply the calibrator estimated by `fit.probability_calibration()` to
  # `new_data`. The fitted model lives in `object$results$fit`; the operation's
  # arguments only hold the calibration `type`, so there is no pre-trained
  # `calibrator` argument to read from any more.
  probably::cal_apply(new_data, object$results$fit)
}

# todo probably needs required_pkgs methods for cal objects
Expand Down
2 changes: 1 addition & 1 deletion R/container.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ fit.container <- function(object, .data, outcome, estimate, probabilities = c(),

num_oper <- length(object$operations)
for (op in seq_len(num_oper)) {
object$operations[[op]] <- fit(object$operations[[op]], data, object)
object$operations[[op]] <- fit(object$operations[[op]], .data, object)
.data <- predict(object$operations[[op]], .data, object)
}

Expand Down
49 changes: 49 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,52 @@ check_container <- function(x, call = caller_env(), arg = caller_arg(x)) {

invisible()
}

# Calibration methods accepted for each container type. Each value is the
# suffix of a `cal_estimate_*()` function in the probably package.
types_regression <- c("linear", "isotonic", "isotonic_boot")
types_binary <- c("logistic", "beta", "isotonic", "isotonic_boot")
types_multiclass <- c("multinomial", "beta", "isotonic", "isotonic_boot")

# Validate, or fill in a default for, the calibration `type` of an adjustment.
#
# adjust_type:    the user-supplied type, or NULL to request the default.
# container_type: the container's type; "regression", "binary" and
#                 "multiclass" are recognized explicitly.
# arg, call:      argument name and calling environment used in error messages.
#
# Returns the default type when `adjust_type` is NULL and the container type is
# recognized; otherwise returns `adjust_type` after checking that it is one of
# the values allowed for the container type (or, for an unrecognized container
# type, one of the values allowed for any type).
check_type <- function(adjust_type,
                       container_type,
                       arg = caller_arg(adjust_type),
                       call = caller_env()) {
  # to-do: handle unknown container type (#11 ish)
  if (is.null(adjust_type)) {
    default_type <-
      switch(
        container_type,
        regression = "linear",
        binary = "logistic",
        multiclass = "multinomial"
      )

    # An unrecognized container type has no default; in that case fall through
    # to the validation below with `adjust_type` still NULL.
    if (!is.null(default_type)) {
      return(default_type)
    }
  }

  allowed_types <-
    switch(
      container_type,
      regression = types_regression,
      binary = types_binary,
      multiclass = types_multiclass,
      unique(c(types_regression, types_binary, types_multiclass))
    )

  # Called for its error side effect only; the validated input is returned.
  arg_match0(adjust_type, allowed_types, arg_nm = arg, error_call = call)

  adjust_type
}

18 changes: 9 additions & 9 deletions man/adjust_numeric_calibration.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions man/adjust_probability_calibration.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 5 additions & 13 deletions tests/testthat/_snaps/adjust-numeric-calibration.md
Original file line number Diff line number Diff line change
@@ -1,35 +1,27 @@
# adjustment printing

Code
ctr_reg %>% adjust_numeric_calibration(dummy_reg_cal)
ctr_reg %>% adjust_numeric_calibration()
Message
-- Container -------------------------------------------------------------------
A postprocessor with 1 operation:
A regression postprocessor with 1 operation:
* Re-calibrate numeric predictions.

# errors informatively with bad input

Code
adjust_numeric_calibration(ctr_reg)
Condition
Error in `adjust_numeric_calibration()`:
! `calibrator` is absent but must be supplied.

---

Code
adjust_numeric_calibration(ctr_reg, "boop")
Condition
Error in `adjust_numeric_calibration()`:
! `calibrator` should be a <cal_regression> object (`?probably::cal_estimate_linear()`), not a string.
! `type` must be one of "linear", "isotonic", or "isotonic_boot", not "boop".

---

Code
adjust_numeric_calibration(ctr_cls, dummy_cls_cal)
container("classification", "binary") %>% adjust_numeric_calibration("linear")
Condition
Error in `adjust_numeric_calibration()`:
! `calibrator` should be a <cal_regression> object (`?probably::cal_estimate_linear()`), not a <cal_binary> object.
! `type` must be one of "logistic", "beta", "isotonic", or "isotonic_boot", not "linear".

8 changes: 4 additions & 4 deletions tests/testthat/_snaps/adjust-numeric-range.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Message
-- Container -------------------------------------------------------------------
A postprocessor with 1 operation:
A regression postprocessor with 1 operation:
* Constrain numeric predictions to be between [-Inf, Inf].

Expand All @@ -16,7 +16,7 @@
Message
-- Container -------------------------------------------------------------------
A postprocessor with 1 operation:
A regression postprocessor with 1 operation:
* Constrain numeric predictions to be between [?, Inf].

Expand All @@ -27,7 +27,7 @@
Message
-- Container -------------------------------------------------------------------
A postprocessor with 1 operation:
A regression postprocessor with 1 operation:
* Constrain numeric predictions to be between [-1, ?].

Expand All @@ -38,7 +38,7 @@
Message
-- Container -------------------------------------------------------------------
A postprocessor with 1 operation:
A regression postprocessor with 1 operation:
* Constrain numeric predictions to be between [?, 1].

Loading

0 comments on commit 2821988

Please sign in to comment.