Commit
* update doc links
* add BART survival learner
* fix dnnsurv parameter test
* add BART to 'Suggests'
* various small fixes
* add more libraries to Suggests to run new BART example
* update doc
* Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer <[email protected]>
* Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer <[email protected]>
* Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer <[email protected]>
* Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer <[email protected]>
* Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer <[email protected]>
* Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer <[email protected]>
* Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer <[email protected]>
* Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer <[email protected]>
* Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer <[email protected]>
* Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer <[email protected]>
* change K parameter type
* simplify and speed-up the creation of the survival matrix
* add importance parameter, remove factor feature type
* change tag for importance to train + fix small bug
* update doc
* fix tests
* change section name
* remove BART example and extra libraries
* return model list slot and name refactoring
* update doc
* store full posterior survival array (testing version)
* will work with latest version of distr6
* update mlr3proba to 0.5.3 + refactoring
* add which.curve parameter, defaults to 0.5 (median posterior)
* update doc
* update BART test
* remove code after checks (distr6 converts survival array correctly)
* better construction of 'which.curve' parameter
* fix bug (which.curve was always NULL)
* Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer <[email protected]>
* Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer <[email protected]>
* Update R/learner_BART_surv_bart.R Co-authored-by: Sebastian Fischer <[email protected]>
* remove new parameter (to be corrected in another PR)
* changes after code review
* remove delayedAssign + add more review suggestions
* explain better 'varcount.mean'
* small update of BART doc
* add more doc for 'which.curve'
* small style changes
* fix hanging indent
* add no lint

---------

Co-authored-by: Sebastian Fischer <[email protected]>
commit 2dbce82, 1 parent (3855302)
Showing 8 changed files with 491 additions and 3 deletions.
@@ -0,0 +1,218 @@
#' @title Survival Bayesian Additive Regression Trees Learner
#' @author bblodfon
#' @name mlr_learners_surv.bart
#'
#' @description
#' Fits a Bayesian Additive Regression Trees (BART) learner to right-censored
#' survival data. Calls [BART::mc.surv.bart()] from \CRANpkg{BART}.
#'
#' @details
#' Two types of prediction are returned for this learner:
#' 1. `distr`: a 3d survival array with observations as the 1st dimension, time
#' points as the 2nd and posterior draws as the 3rd dimension.
#' 2. `crank`: the expected mortality using [mlr3proba::.surv_return]. The parameter
#' `which.curve` decides which posterior draw (3rd dimension) is used for the
#' calculation of the expected mortality. Note that by default the median posterior
#' is used for the calculation of survival measures that require a `distr`
#' prediction; see more info in [PredictionSurv][mlr3proba::PredictionSurv].
#'
#' @section Custom mlr3 defaults:
#' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}.
#'
#' @section Custom mlr3 parameters:
#' - `quiet` suppresses the messages generated by the wrapped C++ code. It is
#' initialized to `TRUE`.
#' - `importance` selects the type of importance score. Default is `count`;
#' see the documentation of the method `$importance()` for more details.
#' - `which.curve` selects which posterior draw is used for the calculation of
#' the `crank` prediction. A value in (0, 1) is taken as a quantile of the
#' curves, a value greater than 1 is taken as a curve index, and it can also
#' be `'mean'`. By default the **median posterior** is used,
#' i.e. `which.curve` is 0.5.
#'
#' @templateVar id surv.bart
#' @template learner
#'
#' @references
#' `r format_bib("sparapani2021nonparametric", "chipman2010bart")`
#'
#' @template seealso_learner
#' @template example
#' @export
LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART",
  inherit = mlr3proba::LearnerSurv,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      param_set = ps(
        K = p_dbl(default = NULL, special_vals = list(NULL), lower = 1, tags = c("train", "predict")),
        events = p_uty(default = NULL, tags = c("train", "predict")),
        ztimes = p_uty(default = NULL, tags = c("train", "predict")),
        zdelta = p_uty(default = NULL, tags = c("train", "predict")),
        sparse = p_lgl(default = FALSE, tags = "train"),
        theta = p_dbl(default = 0, tags = "train"),
        omega = p_dbl(default = 1, tags = "train"),
        a = p_dbl(default = 0.5, lower = 0.5, upper = 1, tags = "train"),
        b = p_dbl(default = 1L, tags = "train"),
        augment = p_lgl(default = FALSE, tags = "train"),
        rho = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"),
        usequants = p_lgl(default = FALSE, tags = "train"),
        rm.const = p_lgl(default = TRUE, tags = "train"),
        type = p_fct(levels = c("pbart", "lbart"), default = "pbart", tags = "train"),
        ntype = p_int(lower = 1, upper = 3, tags = "train"),
        k = p_dbl(default = 2.0, lower = 0, tags = "train"),
        power = p_dbl(default = 2.0, lower = 0, tags = "train"),
        base = p_dbl(default = 0.95, lower = 0, upper = 1, tags = "train"),
        offset = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"),
        ntree = p_int(default = 50L, lower = 1L, tags = "train"),
        numcut = p_int(default = 100L, lower = 1L, tags = "train"),
        ndpost = p_int(default = 1000L, lower = 1L, tags = "train"),
        nskip = p_int(default = 250L, lower = 0L, tags = "train"),
        keepevery = p_int(default = 10L, lower = 1L, tags = "train"),
        printevery = p_int(default = 100L, lower = 1L, tags = "train"),
        seed = p_int(default = 99L, tags = "train"),
        mc.cores = p_int(default = 2L, lower = 1L, tags = c("train", "predict")),
        nice = p_int(default = 19L, lower = 0L, upper = 19L, tags = c("train", "predict")),
        openmp = p_lgl(default = TRUE, tags = "predict"),
        quiet = p_lgl(default = TRUE, tags = "predict"),
        importance = p_fct(default = "count", levels = c("count", "prob"), tags = "train"),
        which.curve = p_dbl(lower = 0L, special_vals = list("mean"), tags = "predict")
      )

      # custom defaults
      param_set$values = list(mc.cores = 1, quiet = TRUE, importance = "count",
        which.curve = 0.5) # 0.5 quantile => median posterior

      super$initialize(
        id = "surv.bart",
        packages = "BART",
        feature_types = c("logical", "integer", "numeric"),
        predict_types = c("crank", "distr"),
        param_set = param_set,
        properties = c("importance", "missings"),
        man = "mlr3extralearners::mlr_learners_surv.bart",
        label = "Bayesian Additive Regression Trees"
      )
    },

    #' @description
    #' Two types of importance scores are supported based on the value
    #' of the parameter `importance`:
    #' 1. `prob`: The mean selection probability of each feature in the trees,
    #' extracted from the slot `varprob.mean`.
    #' If `sparse = FALSE` (default), this is a fixed constant.
    #' It is recommended to use this option when `sparse = TRUE`.
    #' 2. `count`: The mean observed count of each feature in the trees (average
    #' number of times the feature was used in a tree decision rule across all
    #' posterior draws), extracted from the slot `varcount.mean`.
    #' This is the default importance score.
    #'
    #' In both cases, higher values signify more important variables.
    #'
    #' @return Named `numeric()`.
    importance = function() {
      if (is.null(self$model$model)) {
        stopf("No model stored")
      }

      pars = self$param_set$get_values(tags = "train")

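      # the first element of `varprob.mean` / `varcount.mean` corresponds to the
      # time variable that BART's survival data transformation prepends to the
      # features, so it is dropped from the scores below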
      if (pars$importance == "prob") {
        sort(self$model$model$varprob.mean[-1], decreasing = TRUE)
      } else {
        sort(self$model$model$varcount.mean[-1], decreasing = TRUE)
      }
    }
  ),

  private = list(
    .train = function(task) {
      pars = self$param_set$get_values(tags = "train")
      pars$importance = NULL # not used in the train function

      x.train = as.data.frame(task$data(cols = task$feature_names)) # nolint
      truth = task$truth()
      times = truth[, 1]
      delta = truth[, 2] # delta => status

      list(
        model = invoke(
          BART::mc.surv.bart,
          x.train = x.train,
          times = times,
          delta = delta,
          .args = pars
        ),
        # need these for predict
        x.train = x.train,
        times = times,
        delta = delta
      )
    },

    .predict = function(task) {
      # get parameters with tag "predict"
      pars = self$param_set$get_values(tags = "predict")

      # get newdata and ensure same ordering in train and predict
      x.test = as.data.frame(ordered_features(task, self)) # nolint

      # subset parameters to use in `surv.pre.bart`
      pars_pre = pars[names(pars) %in% c("K", "events", "ztimes", "zdelta")]

      # transform data to be suitable for BART survival analysis (needs train data)
      trans_data = invoke(
        BART::surv.pre.bart,
        times = self$model$times,
        delta = self$model$delta,
        x.train = self$model$x.train,
        x.test = x.test,
        .args = pars_pre
      )
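      # tx.test stacks each test observation over the K unique training event
      # times (with the time variable as the first column), which is the layout
      # that predict() on the fitted BART model expects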

      # subset parameters to use in `predict`
      pars_pred = pars[names(pars) %in% c("mc.cores", "nice")]

      pred_fun = function() {
        invoke(
          predict,
          self$model$model,
          newdata = trans_data$tx.test,
          .args = pars_pred
        )
      }

      # don't print C++ generated info during prediction
      if (pars$quiet) {
        utils::capture.output({
          pred = pred_fun()
        })
      } else {
        pred = pred_fun()
      }

      # Number of test observations
      N = task$nrow
      # Number of unique times
      K = pred$K
      times = pred$times
      # Number of posterior draws
      M = nrow(pred$surv.test)

      # Convert full posterior survival matrix to 3D survival array
      # See page 34-35 in Sparapani (2021) for more details
      surv_array = aperm(
        array(pred$surv.test, dim = c(M, K, N), dimnames = list(NULL, times, NULL)),
        c(3, 2, 1)
      )
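      # surv_array[i, j, m] then holds the survival probability of test
      # observation i at the j-th unique time point for posterior draw m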

      # distr => 3d survival array
      # crank => expected mortality
      mlr3proba::.surv_return(times = times, surv = surv_array,
        which.curve = pars$which.curve)
    }
  )
)

.extralrns_dict$add("surv.bart", LearnerSurvLearnerSurvBART)
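
For reference, a minimal usage sketch of the new learner, illustrating the `distr`/`crank` predictions and the `$importance()` method documented above. The toy data, train/test split and shortened MCMC settings are assumptions made purely for illustration and are not part of this commit:

library(mlr3)
library(mlr3proba)
library(mlr3extralearners)

# toy right-censored survival data (purely illustrative)
set.seed(1)
toy = data.frame(
  time   = rexp(50, rate = 0.1),
  status = rbinom(50, size = 1, prob = 0.7),
  x1     = rnorm(50),
  x2     = runif(50)
)
task = as_task_surv(toy, time = "time", event = "status", id = "toy")

# short MCMC chains keep the sketch fast; which.curve = 0.5 is already the default
learner = lrn("surv.bart", ndpost = 100L, nskip = 50L, which.curve = 0.5)

learner$train(task, row_ids = 1:40)
pred = learner$predict(task, row_ids = 41:50)

pred$crank            # expected mortality from the chosen posterior draw
pred$distr            # backed by the 3d array: observations x times x draws
learner$importance()  # mean observed count per feature (the default type)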