Skip to content

Commit

Permalink
add BART survival learner (#290)
Browse files Browse the repository at this point in the history
* update doc links

* add BART survival learner

* fix dnnsurv parameter test

* add BART to 'Suggests'

* various small fixes

* add more libraries to Suggests to run new BART example

* update doc

* Update R/learner_BART_surv_bart.R

Co-authored-by: Sebastian Fischer <[email protected]>

* Update R/learner_BART_surv_bart.R

Co-authored-by: Sebastian Fischer <[email protected]>

* Update R/learner_BART_surv_bart.R

Co-authored-by: Sebastian Fischer <[email protected]>

* Update R/learner_BART_surv_bart.R

Co-authored-by: Sebastian Fischer <[email protected]>

* Update R/learner_BART_surv_bart.R

Co-authored-by: Sebastian Fischer <[email protected]>

* Update R/learner_BART_surv_bart.R

Co-authored-by: Sebastian Fischer <[email protected]>

* Update R/learner_BART_surv_bart.R

Co-authored-by: Sebastian Fischer <[email protected]>

* Update R/learner_BART_surv_bart.R

Co-authored-by: Sebastian Fischer <[email protected]>

* Update R/learner_BART_surv_bart.R

Co-authored-by: Sebastian Fischer <[email protected]>

* Update R/learner_BART_surv_bart.R

Co-authored-by: Sebastian Fischer <[email protected]>

* change K parameter type

* simplify and speed-up the creation of the survival matrix

* add importance parameter, remove factor feature type

* change tag for importance to train + fix small bug

* update doc

* fix tests

* change section name

* remove BART example and extra libraries

* return model list slot and name refactoring

* update doc

* store full posterior survival array (testing version)

* will work with latest version of distr6

* update mlr3proba to 0.5.3 + refactoring

* add which.curve parameter, defaults to 0.5 (median posterior)

* update doc

* update BART test

* remove code after checks (distr6 converts survival array correctly)

* better construction of 'which.curve' parameter

* fix bug (which.curve was always NULL)

* Update R/learner_BART_surv_bart.R

Co-authored-by: Sebastian Fischer <[email protected]>

* Update R/learner_BART_surv_bart.R

Co-authored-by: Sebastian Fischer <[email protected]>

* Update R/learner_BART_surv_bart.R

Co-authored-by: Sebastian Fischer <[email protected]>

* remove new parameter (to be corrected in another PR)

* changes after code review

* remove delayedAssign + add more review suggestions

* explain better 'varcount.mean'

* small update of BART doc

* add more doc for 'which.curve'

* small style changes

* fix hanging indent

* add no lint

---------

Co-authored-by: Sebastian Fischer <[email protected]>
  • Loading branch information
bblodfon and sebffischer authored Oct 17, 2023
1 parent 3855302 commit 2dbce82
Show file tree
Hide file tree
Showing 8 changed files with 491 additions and 3 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Suggests:
aorsf (>= 0.0.5),
actuar,
apcluster,
BART (>= 2.9.4),
C50,
coin,
CoxBoost,
Expand Down Expand Up @@ -66,7 +67,7 @@ Suggests:
mgcv,
mlr3cluster,
mlr3learners (>= 0.4.2),
mlr3proba,
mlr3proba (>= 0.5.3),
mlr3pipelines,
mvtnorm,
nnet,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ export(LearnerSurvGAMBoost)
export(LearnerSurvGBM)
export(LearnerSurvGLMBoost)
export(LearnerSurvGlmnet)
export(LearnerSurvLearnerSurvBART)
export(LearnerSurvLogisticHazard)
export(LearnerSurvMBoost)
export(LearnerSurvNelson)
Expand Down
218 changes: 218 additions & 0 deletions R/learner_BART_surv_bart.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
#' @title Survival Bayesian Additive Regression Trees Learner
#' @author bblodfon
#' @name mlr_learners_surv.bart
#'
#' @description
#' Fits a Bayesian Additive Regression Trees (BART) learner to right-censored
#' survival data. Calls [BART::mc.surv.bart()] from \CRANpkg{BART}.
#'
#' @details
#' Two types of prediction are returned for this learner:
#' 1. `distr`: a 3d survival array with observations as 1st dimension, time
#' points as 2nd and the posterior draws as 3rd dimension.
#' 2. `crank`: the expected mortality using [mlr3proba::.surv_return]. The parameter
#' `which.curve` decides which posterior draw (3rd dimension) will be used for the
#' calculation of the expected mortality. Note that the median posterior is
#' by default used for the calculation of survival measures that require a `distr`
#' prediction, see more info on [PredictionSurv][mlr3proba::PredictionSurv].
#'
#' @section Custom mlr3 defaults:
#' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}.
#'
#' @section Custom mlr3 parameters:
#' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is
#' initialized to `TRUE`.
#' - `importance` allows to choose the type of importance. Default is `count`,
#' see documentation of method `$importance()` for more details.
#' - `which.curve` allows choosing which posterior draw is used for the
#' calculation of the `crank` prediction. If it is between (0,1) it is taken as
#' the quantile of the curves; if it is greater than 1 it is taken as the curve
#' index. It can also be `"mean"`. By default the **median posterior** is used,
#' i.e. `which.curve` is 0.5.
#'
#' @templateVar id surv.bart
#' @template learner
#'
#' @references
#' `r format_bib("sparapani2021nonparametric", "chipman2010bart")`
#'
#' @template seealso_learner
#' @template example
#' @export
LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART",
  inherit = mlr3proba::LearnerSurv,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      # Parameters tagged "train" are forwarded to `BART::mc.surv.bart()`,
      # parameters tagged "predict" are consumed in `private$.predict()`.
      param_set = ps(
        K = p_dbl(default = NULL, special_vals = list(NULL), lower = 1, tags = c("train", "predict")),
        events = p_uty(default = NULL, tags = c("train", "predict")),
        ztimes = p_uty(default = NULL, tags = c("train", "predict")),
        zdelta = p_uty(default = NULL, tags = c("train", "predict")),
        sparse = p_lgl(default = FALSE, tags = "train"),
        theta = p_dbl(default = 0, tags = "train"),
        omega = p_dbl(default = 1, tags = "train"),
        a = p_dbl(default = 0.5, lower = 0.5, upper = 1, tags = "train"),
        b = p_dbl(default = 1L, tags = "train"),
        augment = p_lgl(default = FALSE, tags = "train"),
        rho = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"),
        usequants = p_lgl(default = FALSE, tags = "train"),
        rm.const = p_lgl(default = TRUE, tags = "train"),
        type = p_fct(levels = c("pbart", "lbart"), default = "pbart", tags = "train"),
        ntype = p_int(lower = 1, upper = 3, tags = "train"),
        k = p_dbl(default = 2.0, lower = 0, tags = "train"),
        power = p_dbl(default = 2.0, lower = 0, tags = "train"),
        base = p_dbl(default = 0.95, lower = 0, upper = 1, tags = "train"),
        offset = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"),
        ntree = p_int(default = 50L, lower = 1L, tags = "train"),
        numcut = p_int(default = 100L, lower = 1L, tags = "train"),
        ndpost = p_int(default = 1000L, lower = 1L, tags = "train"),
        nskip = p_int(default = 250L, lower = 0L, tags = "train"),
        keepevery = p_int(default = 10L, lower = 1L, tags = "train"),
        printevery = p_int(default = 100L, lower = 1L, tags = "train"),
        seed = p_int(default = 99L, tags = "train"),
        mc.cores = p_int(default = 2L, lower = 1L, tags = c("train", "predict")),
        nice = p_int(default = 19L, lower = 0L, upper = 19L, tags = c("train", "predict")),
        openmp = p_lgl(default = TRUE, tags = "predict"),
        quiet = p_lgl(default = TRUE, tags = "predict"),
        importance = p_fct(default = "count", levels = c("count", "prob"), tags = "train"),
        which.curve = p_dbl(lower = 0L, special_vals = list("mean"), tags = "predict")
      )

      # custom defaults
      param_set$values = list(mc.cores = 1, quiet = TRUE, importance = "count",
        which.curve = 0.5) # 0.5 quantile => median posterior

      super$initialize(
        id = "surv.bart",
        packages = "BART",
        feature_types = c("logical", "integer", "numeric"),
        predict_types = c("crank", "distr"),
        param_set = param_set,
        properties = c("importance", "missings"),
        man = "mlr3extralearners::mlr_learners_surv.bart",
        label = "Bayesian Additive Regression Trees"
      )
    },

    #' @description
    #' Two types of importance scores are supported based on the value
    #' of the parameter `importance`:
    #' 1. `prob`: The mean selection probability of each feature in the trees,
    #' extracted from the slot `varprob.mean`.
    #' If `sparse = FALSE` (default), this is a fixed constant.
    #' Recommended to use this option when `sparse = TRUE`.
    #' 2. `count`: The mean observed count of each feature in the trees (average
    #' number of times the feature was used in a tree decision rule across all
    #' posterior draws), extracted from the slot `varcount.mean`.
    #' These are the default importance scores.
    #'
    #' In both cases, higher values signify more important variables.
    #'
    #' @return Named `numeric()`.
    importance = function() {
      if (is.null(self$model$model)) {
        stopf("No model stored")
      }

      pars = self$param_set$get_values(tags = "train")

      # `identical()` instead of `==`: if `importance` was unset, `NULL == "prob"`
      # has length zero and `if` would fail; unset falls back to the declared
      # default ("count").
      slot = if (identical(pars$importance, "prob")) "varprob.mean" else "varcount.mean"

      # The first element is dropped (presumably the time covariate that BART
      # prepends to the features for survival data - TODO confirm).
      sort(self$model$model[[slot]][-1], decreasing = TRUE)
    }
  ),

  private = list(
    # Fits the BART survival model on the task's features.
    # Returns a list holding the fitted model (`model`) together with the
    # training features, times and status, which are needed again in `.predict()`.
    .train = function(task) {
      pars = self$param_set$get_values(tags = "train")
      pars$importance = NULL # not used in the train function

      x.train = as.data.frame(task$data(cols = task$feature_names)) # nolint
      truth = task$truth()
      times = truth[, 1]
      delta = truth[, 2] # delta => status

      list(
        model = invoke(
          BART::mc.surv.bart,
          x.train = x.train,
          times = times,
          delta = delta,
          .args = pars
        ),
        # need these for predict
        x.train = x.train,
        times = times,
        delta = delta
      )
    },

    # Predicts on new data: transforms the test data with
    # `BART::surv.pre.bart()`, calls `predict()` on the stored model, reshapes
    # the posterior survival matrix into a 3d array and passes it to
    # `mlr3proba::.surv_return()`.
    .predict = function(task) {
      # get parameters with tag "predict"
      pars = self$param_set$get_values(tags = "predict")

      # get newdata and ensure same ordering in train and predict
      x.test = as.data.frame(ordered_features(task, self)) # nolint

      # subset parameters to use in `surv.pre.bart`
      pars_pre = pars[names(pars) %in% c("K", "events", "ztimes", "zdelta")]

      # transform data to be suitable for BART survival analysis (needs train data)
      trans_data = invoke(
        BART::surv.pre.bart,
        times = self$model$times,
        delta = self$model$delta,
        x.train = self$model$x.train,
        x.test = x.test,
        .args = pars_pre
      )

      # subset parameters to use in `predict`; `openmp` is included so that a
      # user-supplied value is actually forwarded instead of silently ignored
      pars_pred = pars[names(pars) %in% c("mc.cores", "nice", "openmp")]

      pred_fun = function() {
        invoke(
          predict,
          self$model$model,
          newdata = trans_data$tx.test,
          .args = pars_pred
        )
      }

      # `quiet` may have been unset by the user; fall back to its declared
      # default (TRUE) instead of failing on `if (NULL)`
      quiet = if (is.null(pars$quiet)) TRUE else pars$quiet

      # don't print C++ generated info during prediction
      if (quiet) {
        utils::capture.output({
          pred = pred_fun()
        })
      } else {
        pred = pred_fun()
      }

      # Number of test observations
      N = task$nrow
      # Number of unique times
      K = pred$K
      times = pred$times
      # Number of posterior draws
      M = nrow(pred$surv.test)

      # Convert full posterior survival matrix to 3D survival array
      # See page 34-35 in Sparapani (2021) for more details
      # The matrix is first viewed as (draws, times, observations) and then
      # permuted to (observations, times, draws).
      surv_array = aperm(
        array(pred$surv.test, dim = c(M, K, N), dimnames = list(NULL, times, NULL)),
        c(3, 2, 1)
      )

      # distr => 3d survival array
      # crank => expected mortality
      mlr3proba::.surv_return(times = times, surv = surv_array,
        which.curve = pars$which.curve)
    }
  )
)

# Register the learner class under the key "surv.bart" in `.extralrns_dict`
# (presumably the package-level learner dictionary, defined outside this file).
.extralrns_dict$add("surv.bart", LearnerSurvLearnerSurvBART)
Loading

0 comments on commit 2dbce82

Please sign in to comment.