Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add helper for bridging causal fits #199

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: workflows
Title: Modeling Workflows
Version: 1.1.3.9000
Version: 1.1.3.9001
Authors@R: c(
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
person("Simon", "Couch", , "[email protected]", role = c("aut", "cre"),
Expand All @@ -24,7 +24,7 @@ Imports:
hardhat (>= 1.2.0),
lifecycle (>= 1.0.3),
modelenv (>= 0.1.0),
parsnip (>= 1.0.3),
parsnip (>= 1.1.0.9001),
rlang (>= 1.0.3),
tidyselect (>= 1.2.0),
vctrs (>= 0.4.1)
Expand All @@ -38,6 +38,8 @@ Suggests:
recipes (>= 1.0.0),
rmarkdown,
testthat (>= 3.0.0)
Remotes:
tidymodels/parsnip#955
VignetteBuilder:
knitr
Config/Needs/website:
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ S3method(print,workflow)
S3method(tidy,workflow)
S3method(tunable,workflow)
S3method(tune_args,workflow)
S3method(weight_propensity,workflow)
export(.fit_finalize)
export(.fit_model)
export(.fit_pre)
Expand Down Expand Up @@ -69,4 +70,5 @@ importFrom(hardhat,extract_recipe)
importFrom(hardhat,extract_spec_parsnip)
importFrom(lifecycle,deprecated)
importFrom(parsnip,fit_xy)
importFrom(parsnip,weight_propensity)
importFrom(stats,predict)
43 changes: 43 additions & 0 deletions R/weight_propensity.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#' Helper for bridging two-stage causal fits
#'
#' @inherit parsnip::weight_propensity.model_fit description
#'
#' @inheritParams parsnip::weight_propensity.model_fit
#'
#' @inherit parsnip::weight_propensity.model_fit return
#'
#' @inherit parsnip::weight_propensity.model_fit references
#'
#' @importFrom parsnip weight_propensity
#' @method weight_propensity workflow
#' @export
weight_propensity.workflow <- function(object,
wt_fn,
.treated = extract_fit_parsnip(object)$lvl[2],
...,
data) {
if (rlang::is_missing(wt_fn) || !is.function(wt_fn)) {
abort("`wt_fn` must be a function.")
}

if (rlang::is_missing(data) || !is.data.frame(data)) {
abort("`data` must be the data supplied as the data argument to `fit()`.")
}

if (!is_trained_workflow(object)) {
abort("`weight_propensity()` is not well-defined for an unfitted workflow.")
}

outcome_name <- names(object$pre$mold$outcomes)

preds <- predict(object, data, type = "prob")
preds <- preds[[paste0(".pred_", .treated)]]

data$.wts <-
hardhat::importance_weights(
wt_fn(preds, data[[outcome_name]], .treated = .treated, ...)
)

data
}

53 changes: 53 additions & 0 deletions man/weight_propensity.workflow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 50 additions & 0 deletions tests/testthat/_snaps/weight_propensity.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# errors informatively with bad input

Code
weight_propensity(wf, silly_wt_fn, data = two_class_dat)
Condition
Error in `weight_propensity()`:
! `weight_propensity()` is not well-defined for an unfitted workflow.

---

Code
weight_propensity(wf_fit, data = two_class_dat)
Condition
Error in `weight_propensity()`:
! `wt_fn` must be a function.

---

Code
weight_propensity(wf_fit, "boop", data = two_class_dat)
Condition
Error in `weight_propensity()`:
! `wt_fn` must be a function.

---

Code
weight_propensity(wf_fit, function(...) {
-1L
}, data = two_class_dat)
Condition
Error in `hardhat::importance_weights()`:
! `x` can't contain negative weights.

---

Code
weight_propensity(wf_fit, silly_wt_fn)
Condition
Error in `weight_propensity()`:
! `data` must be the data supplied as the data argument to `fit()`.

---

Code
weight_propensity(wf_fit, silly_wt_fn, data = "boop")
Condition
Error in `weight_propensity()`:
! `data` must be the data supplied as the data argument to `fit()`.

63 changes: 63 additions & 0 deletions tests/testthat/test-weight_propensity.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
test_that("basic functionality", {
skip_if_not_installed("modeldata")
library(modeldata)
library(parsnip)

silly_wt_fn <- function(.propensity, .exposure = NULL, ...) {
seq(1, 2, length.out = length(.propensity))
}

lr_fit <- fit(workflow(Class ~ A + B, logistic_reg()), two_class_dat)

lr_res1 <- weight_propensity(lr_fit, silly_wt_fn, data = two_class_dat)
expect_s3_class(lr_res1, "tbl_df")
expect_true(all(names(lr_res1) %in% c(names(two_class_dat), ".wts")))
expect_equal(lr_res1$.wts, importance_weights(seq(1, 2, length.out = nrow(two_class_dat))))
})

test_that("errors informatively with bad input", {
skip_if_not_installed("modeldata")
library(modeldata)
library(parsnip)

silly_wt_fn <- function(.propensity, .exposure = NULL, ...) {
seq(1, 2, length.out = length(.propensity))
}

# untrained workflow
wf <- workflow(Class ~ A + B, logistic_reg())

expect_snapshot(
error = TRUE,
weight_propensity(wf, silly_wt_fn, data = two_class_dat)
)

# bad `wt_fn`
wf_fit <- fit(wf, two_class_dat)

expect_snapshot(
error = TRUE,
weight_propensity(wf_fit, data = two_class_dat)
)

expect_snapshot(
error = TRUE,
weight_propensity(wf_fit, "boop", data = two_class_dat)
)

expect_snapshot(
error = TRUE,
weight_propensity(wf_fit, function(...) {-1L}, data = two_class_dat)
)

# bad `data`
expect_snapshot(
error = TRUE,
weight_propensity(wf_fit, silly_wt_fn)
)

expect_snapshot(
error = TRUE,
weight_propensity(wf_fit, silly_wt_fn, data = "boop")
)
})