diff --git a/R/learner_BART_surv_bart.R b/R/learner_BART_surv_bart.R index 376d44b5b..4ec31f464 100644 --- a/R/learner_BART_surv_bart.R +++ b/R/learner_BART_surv_bart.R @@ -20,9 +20,15 @@ #' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}. #' #' @section Custom mlr3 parameters: -#' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is initialized to `TRUE`. -#' - `importance` allows to choose the type of importance. Default is `count`, see documentation of method `$importance()` for more details. -#' - `which.curve` allows to choose which posterior draw will be used for the calculation of the `crank` prediction. If between (0,1) it is taken as the quantile of the curves otherwise if greater than 1 it is taken as the curve index, can also be 'mean'. By default the **median posterior** is used, i.e. `which.curve` is 0.5. +#' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is +#' initialized to `TRUE`. +#' - `importance` allows to choose the type of importance. Default is `count`, +#' see documentation of method `$importance()` for more details. +#' - `which.curve` allows to choose which posterior draw will be used for the +#' calculation of the `crank` prediction. If between (0,1) it is taken as the +#' quantile of the curves otherwise if greater than 1 it is taken as the curve +#' index, can also be 'mean'. By default the **median posterior** is used, +#' i.e. `which.curve` is 0.5. #' #' @templateVar id surv.bart #' @template learner @@ -127,8 +133,8 @@ LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART", x.train = as.data.frame(task$data(cols = task$feature_names)) truth = task$truth() - times = truth[,1] - delta = truth[,2] # delta => status + times = truth[, 1] + delta = truth[, 2] # delta => status list( model = invoke( @@ -196,14 +202,14 @@ LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART", # Convert full posterior survival matrix to 3D survival array # See page 34-35 in Sparapani (2021) for more details - surv.array = aperm( + surv_array = aperm( array(pred$surv.test, dim = c(M, K, N), dimnames = list(NULL, times, NULL)), c(3, 2, 1) ) # distr => 3d survival array # crank => expected mortality - mlr3proba::.surv_return(times = times, surv = surv.array, + mlr3proba::.surv_return(times = times, surv = surv_array, which.curve = pars$which.curve) } ) diff --git a/man/mlr_learners_surv.bart.Rd b/man/mlr_learners_surv.bart.Rd index 3045803f6..3c72ce758 100644 --- a/man/mlr_learners_surv.bart.Rd +++ b/man/mlr_learners_surv.bart.Rd @@ -30,9 +30,15 @@ prediction, see more info on \link[mlr3proba:PredictionSurv]{PredictionSurv}. \section{Custom mlr3 parameters}{ \itemize{ -\item \code{quiet} allows to suppress messages generated by the wrapped C++ code. Is initialized to \code{TRUE}. -\item \code{importance} allows to choose the type of importance. Default is \code{count}, see documentation of method \verb{$importance()} for more details. -\item \code{which.curve} allows to choose which posterior draw will be used for the calculation of the \code{crank} prediction. If between (0,1) it is taken as the quantile of the curves otherwise if greater than 1 it is taken as the curve index, can also be 'mean'. By default the \strong{median posterior} is used, i.e. \code{which.curve} is 0.5. +\item \code{quiet} allows to suppress messages generated by the wrapped C++ code. Is +initialized to \code{TRUE}. +\item \code{importance} allows to choose the type of importance. Default is \code{count}, +see documentation of method \verb{$importance()} for more details. +\item \code{which.curve} allows to choose which posterior draw will be used for the +calculation of the \code{crank} prediction. If between (0,1) it is taken as the +quantile of the curves otherwise if greater than 1 it is taken as the curve +index, can also be 'mean'. By default the \strong{median posterior} is used, +i.e. \code{which.curve} is 0.5. } }