Skip to content

Commit

Permalink
small style changes
Browse files Browse the repository at this point in the history
  • Loading branch information
bblodfon committed Oct 16, 2023
1 parent 1e66dc9 commit 83d09b9
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
20 changes: 13 additions & 7 deletions R/learner_BART_surv_bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,15 @@
#' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}.
#'
#' @section Custom mlr3 parameters:
#' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is initialized to `TRUE`.
#' - `importance` allows to choose the type of importance. Default is `count`, see documentation of method `$importance()` for more details.
#' - `which.curve` allows to choose which posterior draw will be used for the calculation of the `crank` prediction. If between (0,1) it is taken as the quantile of the curves otherwise if greater than 1 it is taken as the curve index, can also be 'mean'. By default the **median posterior** is used, i.e. `which.curve` is 0.5.
#' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is
#' initialized to `TRUE`.
#' - `importance` allows to choose the type of importance. Default is `count`,
#' see documentation of method `$importance()` for more details.
#' - `which.curve` allows to choose which posterior draw will be used for the
#' calculation of the `crank` prediction. If between (0,1) it is taken as the
#' quantile of the curves otherwise if greater than 1 it is taken as the curve
#' index, can also be 'mean'. By default the **median posterior** is used,
#' i.e. `which.curve` is 0.5.
#'
#' @templateVar id surv.bart
#' @template learner
Expand Down Expand Up @@ -127,8 +133,8 @@ LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART",

x.train = as.data.frame(task$data(cols = task$feature_names))
truth = task$truth()
times = truth[,1]
delta = truth[,2] # delta => status
times = truth[, 1]
delta = truth[, 2] # delta => status

list(
model = invoke(
Expand Down Expand Up @@ -196,14 +202,14 @@ LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART",

# Convert full posterior survival matrix to 3D survival array
# See page 34-35 in Sparapani (2021) for more details
surv.array = aperm(
surv_array = aperm(
array(pred$surv.test, dim = c(M, K, N), dimnames = list(NULL, times, NULL)),
c(3, 2, 1)
)

# distr => 3d survival array
# crank => expected mortality
mlr3proba::.surv_return(times = times, surv = surv.array,
mlr3proba::.surv_return(times = times, surv = surv_array,
which.curve = pars$which.curve)
}
)
Expand Down
12 changes: 9 additions & 3 deletions man/mlr_learners_surv.bart.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 83d09b9

Please sign in to comment.