Skip to content

Commit

Permalink
operation -> adjustment (closes #19)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Jun 4, 2024
1 parent e4fc3c9 commit 2c3070c
Show file tree
Hide file tree
Showing 18 changed files with 102 additions and 102 deletions.
8 changes: 4 additions & 4 deletions R/adjust-equivocal-zone.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ adjust_equivocal_zone <- function(x, value = 0.1, threshold = 1 / 2) {
check_number_decimal(threshold, min = 10^-10, max = 1 - 10^-10)
}

op <-
new_operation(
adj <-
new_adjustment(
"equivocal_zone",
inputs = "probability",
outputs = "class",
Expand All @@ -45,7 +45,7 @@ adjust_equivocal_zone <- function(x, value = 0.1, threshold = 1 / 2) {

new_tailor(
type = x$type,
operations = c(x$operations, list(op)),
adjustments = c(x$adjustments, list(adj)),
columns = x$dat,
ptype = x$ptype,
call = current_env()
Expand All @@ -70,7 +70,7 @@ print.equivocal_zone <- function(x, ...) {

#' @export
fit.equivocal_zone <- function(object, data, tailor = NULL, ...) {
new_operation(
new_adjustment(
class(object),
inputs = object$inputs,
outputs = object$outputs,
Expand Down
8 changes: 4 additions & 4 deletions R/adjust-numeric-calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ adjust_numeric_calibration <- function(x, method = NULL) {
)
}

op <-
new_operation(
adj <-
new_adjustment(
"numeric_calibration",
inputs = "numeric",
outputs = "numeric",
Expand All @@ -51,7 +51,7 @@ adjust_numeric_calibration <- function(x, method = NULL) {

new_tailor(
type = x$type,
operations = c(x$operations, list(op)),
adjustments = c(x$adjustments, list(adj)),
columns = x$dat,
ptype = x$ptype,
call = current_env()
Expand Down Expand Up @@ -81,7 +81,7 @@ fit.numeric_calibration <- function(object, data, tailor = NULL, ...) {
)
)

new_operation(
new_adjustment(
class(object),
inputs = object$inputs,
outputs = object$outputs,
Expand Down
8 changes: 4 additions & 4 deletions R/adjust-numeric-range.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ adjust_numeric_range <- function(x, lower_limit = -Inf, upper_limit = Inf) {
# remaining input checks are done via probably::bound_prediction
check_tailor(x)

op <-
new_operation(
adj <-
new_adjustment(
"numeric_range",
inputs = "numeric",
outputs = "numeric",
Expand All @@ -21,7 +21,7 @@ adjust_numeric_range <- function(x, lower_limit = -Inf, upper_limit = Inf) {

new_tailor(
type = x$type,
operations = c(x$operations, list(op)),
adjustments = c(x$adjustments, list(adj)),
columns = x$dat,
ptype = x$ptype,
call = current_env()
Expand Down Expand Up @@ -59,7 +59,7 @@ print.numeric_range <- function(x, ...) {

#' @export
fit.numeric_range <- function(object, data, tailor = NULL, ...) {
new_operation(
new_adjustment(
class(object),
inputs = object$inputs,
outputs = object$outputs,
Expand Down
8 changes: 4 additions & 4 deletions R/adjust-predictions-custom.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ adjust_predictions_custom <- function(x, ..., .pkgs = character(0)) {
check_tailor(x)
cmds <- enquos(...)

op <-
new_operation(
adj <-
new_adjustment(
"predictions_custom",
inputs = "everything",
outputs = "everything",
Expand All @@ -43,7 +43,7 @@ adjust_predictions_custom <- function(x, ..., .pkgs = character(0)) {

new_tailor(
type = x$type,
operations = c(x$operations, list(op)),
adjustments = c(x$adjustments, list(adj)),
columns = x$dat,
ptype = x$ptype,
call = current_env()
Expand All @@ -59,7 +59,7 @@ print.predictions_custom <- function(x, ...) {

#' @export
fit.predictions_custom <- function(object, data, tailor = NULL, ...) {
new_operation(
new_adjustment(
class(object),
inputs = object$inputs,
outputs = object$outputs,
Expand Down
8 changes: 4 additions & 4 deletions R/adjust-probability-calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ adjust_probability_calibration <- function(x, method = NULL) {
)
}

op <-
new_operation(
adj <-
new_adjustment(
"probability_calibration",
inputs = "probability",
outputs = "probability_class",
Expand All @@ -30,7 +30,7 @@ adjust_probability_calibration <- function(x, method = NULL) {

new_tailor(
type = x$type,
operations = c(x$operations, list(op)),
adjustments = c(x$adjustments, list(adj)),
columns = x$dat,
ptype = x$ptype,
call = current_env()
Expand Down Expand Up @@ -62,7 +62,7 @@ fit.probability_calibration <- function(object, data, tailor = NULL, ...) {
)
)

new_operation(
new_adjustment(
class(object),
inputs = object$inputs,
outputs = object$outputs,
Expand Down
8 changes: 4 additions & 4 deletions R/adjust-probability-threshold.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ adjust_probability_threshold <- function(x, threshold = 0.5) {
check_number_decimal(threshold, min = 10^-10, max = 1 - 10^-10)
}

op <-
new_operation(
adj <-
new_adjustment(
"probability_threshold",
inputs = "probability",
outputs = "class",
Expand All @@ -41,7 +41,7 @@ adjust_probability_threshold <- function(x, threshold = 0.5) {

new_tailor(
type = x$type,
operations = c(x$operations, list(op)),
adjustments = c(x$adjustments, list(adj)),
columns = x$dat,
ptype = x$ptype,
call = current_env()
Expand All @@ -66,7 +66,7 @@ print.probability_threshold <- function(x, ...) {

#' @export
fit.probability_threshold <- function(object, data, tailor = NULL, ...) {
new_operation(
new_adjustment(
class(object),
inputs = object$inputs,
outputs = object$outputs,
Expand Down
48 changes: 24 additions & 24 deletions R/tailor.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,33 +27,33 @@ tailor <- function(type = "unknown", outcome = NULL, estimate = NULL,

new_tailor(
type,
operations = list(),
adjustments = list(),
columns = columns,
ptype = tibble::new_tibble(list()),
call = current_env()
)
}

new_tailor <- function(type, operations, columns, ptype, call) {
new_tailor <- function(type, adjustments, columns, ptype, call) {
type <- arg_match0(type, c("unknown", "regression", "binary", "multiclass"))

if (!is.list(operations)) {
cli_abort("The {.arg operations} argument should be a list.", call = call)
if (!is.list(adjustments)) {
cli_abort("The {.arg adjustments} argument should be a list.", call = call)
}

is_oper <- purrr::map_lgl(operations, ~ inherits(.x, "operation"))
if (length(is_oper) > 0 && !any(is_oper)) {
bad_oper <- names(is_oper)[!is_oper]
cli_abort("The following {.arg operations} do not have the class \\
{.val operation}: {bad_oper}.", call = call)
is_adjustment <- purrr::map_lgl(adjustments, ~ inherits(.x, "adjustment"))
if (length(is_adjustment) > 0 && !any(is_adjustment)) {
bad_adjustment <- names(is_adjustment)[!is_adjustment]
cli_abort("The following {.arg adjustments} do not have the class \\
{.val adjustment}: {bad_adjustment}.", call = call)
}

# validate operation order and check duplicates
validate_order(operations, type, call)
# validate adjustment order and check duplicates
validate_order(adjustments, type, call)

# check columns
res <- list(
type = type, operations = operations,
type = type, adjustments = adjustments,
columns = columns, ptype = ptype
)
class(res) <- "tailor"
Expand All @@ -64,15 +64,15 @@ new_tailor <- function(type, operations, columns, ptype, call) {
print.tailor <- function(x, ...) {
cli::cli_h1("tailor")

num_op <- length(x$operations)
num_adj <- length(x$adjustments)
cli::cli_text(
"A {ifelse(x$type == 'unknown', '', x$type)} postprocessor \\
with {num_op} operation{?s}{cli::qty(num_op+1)}{?./:}"
with {num_adj} adjustment{?s}{cli::qty(num_adj+1)}{?./:}"
)

if (num_op > 0) {
if (num_adj > 0) {
cli::cli_text("\n")
res <- purrr::map(x$operations, print)
res <- purrr::map(x$adjustments, print)
}

invisible(x)
Expand Down Expand Up @@ -112,16 +112,16 @@ fit.tailor <- function(object, .data, outcome, estimate, probabilities = c(),

object <- new_tailor(
object$type,
operations = object$operations,
adjustments = object$adjustments,
columns = columns,
ptype = ptype,
call = current_env()
)

num_oper <- length(object$operations)
for (op in seq_len(num_oper)) {
object$operations[[op]] <- fit(object$operations[[op]], .data, object)
.data <- predict(object$operations[[op]], .data, object)
num_adjustment <- length(object$adjustments)
for (adj in seq_len(num_adjustment)) {
object$adjustments[[adj]] <- fit(object$adjustments[[adj]], .data, object)
.data <- predict(object$adjustments[[adj]], .data, object)
}

# todo Add a fitted tailor class?
Expand All @@ -131,9 +131,9 @@ fit.tailor <- function(object, .data, outcome, estimate, probabilities = c(),
#' @export
predict.tailor <- function(object, new_data, ...) {
# validate levels/classes
num_oper <- length(object$operations)
for (op in seq_len(num_oper)) {
new_data <- predict(object$operations[[op]], new_data, object)
num_adjustment <- length(object$adjustments)
for (adj in seq_len(num_adjustment)) {
new_data <- predict(object$adjustments[[adj]], new_data, object)
}
if (!tibble::is_tibble(new_data)) {
new_data <- tibble::as_tibble(new_data)
Expand Down
18 changes: 9 additions & 9 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ is_tune <- function(x) {
isTRUE(identical(quote(tune), x[[1]]))
}

# for operations with no tunable parameters
# for adjustments with no tunable parameters

no_param <-
tibble::tibble(
Expand All @@ -33,7 +33,7 @@ no_param <-
input_vals <- c("numeric", "probability", "class", "everything")
output_vals <- c("numeric", "probability_class", "class", "everything")

new_operation <- function(cls, inputs, outputs, arguments, results = list(),
new_adjustment <- function(cls, inputs, outputs, arguments, results = list(),
trained, requires_fit, ...) {
inputs <- arg_match0(inputs, input_vals)
outputs <- arg_match0(outputs, output_vals)
Expand All @@ -49,7 +49,7 @@ new_operation <- function(cls, inputs, outputs, arguments, results = list(),
trained = trained,
requires_fit = requires_fit
)
class(res) <- c(cls, "operation")
class(res) <- c(cls, "adjustment")
res
}

Expand All @@ -62,25 +62,25 @@ is_tailor <- function(x) {
#' @keywords internal
#' @rdname tailor-internals
tailor_fully_trained <- function(x) {
if (length(x$operations) == 0L) {
if (length(x$adjustments) == 0L) {
return(FALSE)
}

all(purrr::map_lgl(x$operations, tailor_operation_trained))
all(purrr::map_lgl(x$adjustments, tailor_adjustment_trained))
}

tailor_operation_trained <- function(x) {
tailor_adjustment_trained <- function(x) {
isTRUE(x$trained)
}

#' @export
#' @keywords internal
#' @rdname tailor-internals
tailor_requires_fit <- function(x) {
any(purrr::map_lgl(x$operations, tailor_operation_requires_fit))
any(purrr::map_lgl(x$adjustments, tailor_adjustment_requires_fit))
}

tailor_operation_requires_fit <- function(x) {
tailor_adjustment_requires_fit <- function(x) {
isTRUE(x$requires_fit)
}

Expand Down Expand Up @@ -114,7 +114,7 @@ check_calibration_type <- function(calibration_type, calibration_type_expected,
tailor_type, call) {
if (!identical(calibration_type, calibration_type_expected)) {
cli_abort(
"A {.field {tailor_type}} tailor is incompatible with the operation \\
"A {.field {tailor_type}} tailor is incompatible with the adjustment \\
{.fun {paste0('adjust_', calibration_type, '_calibration')}}.",
call = call
)
Expand Down
Loading

0 comments on commit 2c3070c

Please sign in to comment.