Skip to content

Commit

Permalink
feature issue #1489
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Sep 16, 2024
1 parent b441060 commit 5bb6531
Show file tree
Hide file tree
Showing 19 changed files with 242 additions and 87 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ Package: brms
Encoding: UTF-8
Type: Package
Title: Bayesian Regression Models using 'Stan'
Version: 2.21.8
Date: 2024-09-12
Version: 2.21.9
Date: 2024-09-16
Authors@R:
c(person("Paul-Christian", "Bürkner", email = "[email protected]",
role = c("aut", "cre")),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,7 @@ export(read_csv_as_stanfit)
export(recompile_model)
export(reloo)
export(rename_pars)
export(resp_bhaz)
export(resp_cat)
export(resp_cens)
export(resp_dec)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

### New Features

* Support stratified `cox` models via the new addition term `bhaz`. (#1489)
* Support futures for parallelization in the `cmdstanr` backend. (#1684)
* Add method `loo_epred` thanks to Aki Vehtari. (#1641)
* Add priorsense support via `create_priorsense_data.brmsfit` thanks to Noa Kallioinen. (#1354)
Expand Down
18 changes: 12 additions & 6 deletions R/brmsformula.R
Original file line number Diff line number Diff line change
Expand Up @@ -1296,12 +1296,6 @@ validate_formula.brmsformula <- function(
out$family$thres <- extract_thres_names(out, data)
out$family$cats <- extract_cat_names(out, data)
}
if (is.mixfamily(out$family)) {
# every mixture family needs to know about response categories
for (i in seq_along(out$family$mix)) {
out$family$mix[[i]]$thres <- out$family$thres
}
}
}
conv_cats_dpars <- conv_cats_dpars(out$family)
if (conv_cats_dpars && !is.null(data)) {
Expand Down Expand Up @@ -1337,6 +1331,18 @@ validate_formula.brmsformula <- function(
out$family$dpars <- union(dp_dpars, out$family$dpars)
}
}
if (is_cox(out$family) && !is.null(data)) {
# for easy access of baseline hazards
out$family$bhaz <- extract_bhaz(out, data)
}
if (is.mixfamily(out$family)) {
# every mixture family needs to know about additional response information
for (i in seq_along(out$family$mix)) {
for (term in c("cats", "thres", "bhaz")) {
out$family$mix[[i]][[term]] <- out$family[[term]]
}
}
}

# incorporate deprecated arguments
require_threshold <- is_ordinal(out$family) && is.null(out$family$threshold)
Expand Down
4 changes: 2 additions & 2 deletions R/brmsframe.R
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,8 @@ frame_basis_bhaz <- function(x, data, ...) {
if (is_cox(x$family)) {
# compute basis matrix of the baseline hazard for the Cox model
y <- model.response(model.frame(x$respform, data, na.action = na.pass))
out$basis_matrix <- bhaz_basis_matrix(y, args = x$family$bhaz)
args <- family_info(x, "bhaz")$args
out$basis_matrix <- bhaz_basis_matrix(y, args = args)
}
out
}

61 changes: 50 additions & 11 deletions R/data-response.R
Original file line number Diff line number Diff line change
Expand Up @@ -469,14 +469,34 @@ data_bhaz <- function(bframe, data, data2, prior) {
return(out)
}
y <- bframe$frame$resp$values
args <- bframe$family$bhaz
bhaz <- family_info(bframe, "bhaz")
bs <- bframe$basis$bhaz$basis_matrix
out$Zbhaz <- bhaz_basis_matrix(y, args, basis = bs)
out$Zcbhaz <- bhaz_basis_matrix(y, args, integrate = TRUE, basis = bs)
out$Zbhaz <- bhaz_basis_matrix(y, bhaz$args, basis = bs)
out$Zcbhaz <- bhaz_basis_matrix(y, bhaz$args, integrate = TRUE, basis = bs)
out$Kbhaz <- NCOL(out$Zbhaz)
sbhaz_prior <- subset2(prior, class = "sbhaz", resp = bframe$resp)
con_sbhaz <- eval_dirichlet(sbhaz_prior$prior, out$Kbhaz, data2)
out$con_sbhaz <- as.array(con_sbhaz)
groups <- bhaz$groups
if (!is.null(groups)) {
out$ngrbhaz <- length(groups)
gr <- get_ad_values(bframe, "bhaz", "gr", data)
gr <- factor(rename(gr), levels = groups)
out$Jgrbhaz <- match(gr, groups)
out$con_sbhaz <- matrix(nrow = out$ngrbhaz, ncol = out$Kbhaz)
sbhaz_prior <- subset2(prior, class = "sbhaz", resp = bframe$resp)
sbhaz_prior_global <- subset2(sbhaz_prior, group = "")
con_sbhaz_global <- eval_dirichlet(sbhaz_prior_global$prior, out$Kbhaz, data2)
for (k in seq_along(groups)) {
sbhaz_prior_group <- subset2(sbhaz_prior, group = groups[k])
if (nzchar(sbhaz_prior_group$prior)) {
out$con_sbhaz[k, ] <- eval_dirichlet(sbhaz_prior_group$prior, out$Kbhaz, data2)
} else {
out$con_sbhaz[k, ] <- con_sbhaz_global
}
}
} else {
sbhaz_prior <- subset2(prior, class = "sbhaz", resp = bframe$resp)
con_sbhaz <- eval_dirichlet(sbhaz_prior$prior, out$Kbhaz, data2)
out$con_sbhaz <- as.array(con_sbhaz)
}
out
}

Expand All @@ -502,9 +522,6 @@ bhaz_basis_matrix <- function(y, args = list(), integrate = FALSE,
}
stopifnot(is.list(args))
args$x <- y
if (!is.null(args$intercept)) {
args$intercept <- as_one_logical(args$intercept)
}
if (is.null(args$Boundary.knots)) {
# avoid 'knots' outside 'Boundary.knots' error (#1143)
# we also need a smaller lower boundary knot to avoid lp = -Inf
Expand All @@ -524,6 +541,29 @@ bhaz_basis_matrix <- function(y, args = list(), integrate = FALSE,
out
}

# collect the baseline hazard settings of a Cox model so that they can be
# stored in the model family
# @param x a brmsformula or brmsterms object whose family is 'cox'
# @param data data handed to get_ad_values to look up the grouping values
# @return a named list with elements:
#   args: arguments that can be passed to bhaz_basis_matrix
#   groups: optional names of the groups for which to stratify
extract_bhaz <- function(x, data) {
  stopifnot(is.brmsformula(x) || is.brmsterms(x), is_cox(x))
  if (is.null(x$adforms)) {
    # addition terms have not been parsed yet
    x$adforms <- terms_ad(x$formula, x$family)
  }
  bhaz_info <- list()
  if (is.null(x$adforms$bhaz)) {
    # 'bhaz' is optional, so its defaults have to be repeated here
    bhaz_info$args <- list(df = 5, intercept = TRUE)
    return(bhaz_info)
  }
  # the flags of the evaluated addition term hold the basis arguments
  bhaz_info$args <- eval_rhs(x$adforms$bhaz)$flags
  strata <- get_ad_values(x, "bhaz", "gr", data)
  if (!is.null(strata)) {
    # group names are the (renamed) levels of the grouping variable
    bhaz_info$groups <- rename(levels(factor(strata)))
  }
  bhaz_info
}

# extract names of response categories
# @param x a brmsterms object or one that can be coerced to it
# @param data user specified data
Expand All @@ -550,7 +590,6 @@ extract_cat_names <- function(x, data) {
# @return a data.frame with columns 'thres' and 'group'
extract_thres_names <- function(x, data) {
stopifnot(is.brmsformula(x) || is.brmsterms(x), has_thres(x))

if (is.null(x$adforms)) {
x$adforms <- terms_ad(x$formula, x$family)
}
Expand Down Expand Up @@ -609,7 +648,7 @@ extract_thres_names <- function(x, data) {
data.frame(thres, group, stringsAsFactors = FALSE)
}

# extract threshold names from the response values
# extract number of thresholds from the response values
# @param formula with the response on the LHS
# @param data a data.frame from which to extract responses
# @param extra_cat is the first category an extra (hurdle) category?
Expand Down
45 changes: 19 additions & 26 deletions R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
#' category is used as the reference. If \code{NA}, all categories will be
#' predicted, which requires strong priors or carefully specified predictor
#' terms in order to lead to an identified model.
#' @param bhaz Currently for experimental purposes only.
#'
#' @details
#' Below, we list common use cases for the different families.
Expand Down Expand Up @@ -199,7 +198,7 @@ brmsfamily <- function(family, link = NULL, link_sigma = "log",
link_alpha = "identity",
link_quantile = "logit",
threshold = "flexible",
refcat = NULL, bhaz = NULL) {
refcat = NULL) {
slink <- substitute(link)
.brmsfamily(
family, link = link, slink = slink,
Expand All @@ -212,8 +211,7 @@ brmsfamily <- function(family, link = NULL, link_sigma = "log",
link_ndt = link_ndt, link_bias = link_bias,
link_alpha = link_alpha, link_xi = link_xi,
link_quantile = link_quantile,
threshold = threshold, refcat = refcat,
bhaz = bhaz
threshold = threshold, refcat = refcat
)
}

Expand All @@ -227,7 +225,7 @@ brmsfamily <- function(family, link = NULL, link_sigma = "log",
# @return an object of 'brmsfamily' which inherits from 'family'
.brmsfamily <- function(family, link = NULL, slink = link,
threshold = "flexible",
refcat = NULL, bhaz = NULL, ...) {
refcat = NULL, ...) {
family <- tolower(as_one_character(family))
aux_links <- list(...)
pattern <- c("^normal$", "^zi_", "^hu_")
Expand Down Expand Up @@ -300,23 +298,6 @@ brmsfamily <- function(family, link = NULL, link_sigma = "log",
out$refcat <- as_one_character(refcat, allow_na = allow_na_ref)
}
}
if (is_cox(out$family)) {
if (!is.null(bhaz)) {
if (!is.list(bhaz)) {
stop2("'bhaz' should be a list.")
}
out$bhaz <- bhaz
} else {
out$bhaz <- list()
}
# set default arguments
if (is.null(out$bhaz$df)) {
out$bhaz$df <- 5L
}
if (is.null(out$bhaz$intercept)) {
out$bhaz$intercept <- TRUE
}
}
out
}

Expand Down Expand Up @@ -475,8 +456,8 @@ combine_family_info <- function(x, y, ...) {
clb <- !any(ulapply(x[, 1], isFALSE))
cub <- !any(ulapply(x[, 2], isFALSE))
x <- c(clb, cub)
} else if (y == "thres") {
# thresholds are the same across mixture components
} else if (y %in% c("thres", "bhaz")) {
# same across mixture components
x <- x[[1]]
}
x
Expand Down Expand Up @@ -687,9 +668,9 @@ zero_inflated_asym_laplace <- function(link = "identity", link_sigma = "log",

#' @rdname brmsfamily
#' @export
cox <- function(link = "log", bhaz = NULL) {
cox <- function(link = "log") {
slink <- substitute(link)
.brmsfamily("cox", link = link, bhaz = bhaz)
.brmsfamily("cox", link = link)
}

#' @rdname brmsfamily
Expand Down Expand Up @@ -1750,6 +1731,18 @@ has_thres_groups <- function(family) {
any(nzchar(groups))
}

# unique names of the baseline hazard groups, read from the 'groups'
# element of the family's 'bhaz' information
get_bhaz_groups <- function(family) {
  unique(family_info(family, "bhaz")$groups)
}

# does the model use group-specific baseline hazards,
# i.e. is at least one baseline hazard group name non-empty?
has_bhaz_groups <- function(family) {
  any(nzchar(get_bhaz_groups(family)))
}

has_ndt <- function(family) {
"ndt" %in% dpar_class(family_info(family, "dpars"))
}
Expand Down
2 changes: 1 addition & 1 deletion R/family-lists.R
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@
links = c("log", "identity", "softplus", "squareplus"),
dpars = c("mu"), type = "real",
ybounds = c(0, Inf), closed = c(TRUE, NA),
ad = c("weights", "subset", "cens", "trunc", "index"),
ad = c("weights", "subset", "cens", "trunc", "index", "bhaz"),
include = "fun_cox.stan",
specials = c("cox", "sbi_log", "sbi_log_cdf"),
normalized = ""
Expand Down
12 changes: 12 additions & 0 deletions R/formula-ad.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#' @param denom A vector of positive numeric values specifying
#' the denominator values from which the response rates are computed.
#' @param gr A vector of grouping indicators.
#' @param df Degrees of freedom of baseline hazard splines for Cox models.
#' @param ... For \code{resp_vreal}, vectors of real values.
#' For \code{resp_vint}, vectors of integer values. In Stan,
#' these variables will be named \code{vreal1}, \code{vreal2}, ...,
Expand Down Expand Up @@ -160,6 +161,17 @@ resp_dec <- function(x) {
class_resp_special("dec", call = match.call(), vars = nlist(dec))
}

#' @rdname addition-terms
#' @export
resp_bhaz <- function(gr = NA, df = 5, ...) {
# addition term 'bhaz': collects the baseline hazard settings of Cox models
# 'gr' is captured unevaluated and stored as text, not as values
gr <- deparse0(substitute(gr))
# 'df' is passed through as_one_integer (presumably validating that it
# is a single integer-like value -- helper not visible here)
df <- as_one_integer(df)
# nlist names the first element after the symbol 'df';
# anything passed via ... is stored alongside it
args <- nlist(df, ...)
# non-power users shouldn't know they can change 'intercept'
# falls back to TRUE unless 'intercept' was supplied via ...
# (assuming %||% is the usual NULL-default operator)
args$intercept <- args$intercept %||% TRUE
# 'gr' is handed over as a variable, the basis arguments as flags
class_resp_special("bhaz", call = match.call(), vars = nlist(gr), flags = args)
}

#' @rdname addition-terms
#' @export
resp_cens <- function(x, y2 = NA) {
Expand Down
21 changes: 17 additions & 4 deletions R/prepare_predictions.R
Original file line number Diff line number Diff line change
Expand Up @@ -846,12 +846,25 @@ prepare_predictions_bhaz <- function(bframe, draws, sdata, ...) {
}
out <- list()
p <- usc(combine_prefix(bframe))
sbhaz_regex <- paste0("^sbhaz", p)
sbhaz <- prepare_draws(draws, sbhaz_regex, regex = TRUE)
Zbhaz <- sdata[[paste0("Zbhaz", p)]]
out$bhaz <- tcrossprod(sbhaz, Zbhaz)
Zcbhaz <- sdata[[paste0("Zcbhaz", p)]]
out$cbhaz <- tcrossprod(sbhaz, Zcbhaz)
if (has_bhaz_groups(bframe)) {
groups <- get_bhaz_groups(bframe)
Jgrbhaz <- sdata[[paste0("Jgrbhaz", p)]]
out$bhaz <- out$cbhaz <- matrix(nrow = nrow(draws), ncol = nrow(Zbhaz))
for (k in seq_along(groups)) {
sbhaz_regex <- paste0("^sbhaz", p, "\\[", groups[k], ",")
sbhaz <- prepare_draws(draws, sbhaz_regex, regex = TRUE)
take <- Jgrbhaz == k
out$bhaz[, take] <- tcrossprod(sbhaz, Zbhaz[take, ])
out$cbhaz[, take] <- tcrossprod(sbhaz, Zcbhaz[take, ])
}
} else {
sbhaz_regex <- paste0("^sbhaz", p)
sbhaz <- prepare_draws(draws, sbhaz_regex, regex = TRUE)
out$bhaz <- tcrossprod(sbhaz, Zbhaz)
out$cbhaz <- tcrossprod(sbhaz, Zcbhaz)
}
out
}

Expand Down
4 changes: 4 additions & 0 deletions R/priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,10 @@ prior_bhaz <- function(bframe, ...) {
# the scale of sbhaz is not identified when an intercept is part of mu
# thus a sum-to-one constraint ensures identification
prior <- prior + brmsprior("dirichlet(1)", class = "sbhaz", ls = px)
if (has_bhaz_groups(bframe)) {
groups <- get_bhaz_groups(bframe)
prior <- prior + brmsprior("", class = "sbhaz", ls = px, group = groups)
}
prior
}

Expand Down
20 changes: 20 additions & 0 deletions R/rename_pars.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ rename_predictor.brmsterms <- function(x, ...) {
c(out) <- rename_Ymi(x, ...)
}
c(out) <- rename_thres(x, ...)
c(out) <- rename_bhaz(x, ...)
c(out) <- rename_family_cor_pars(x, ...)
out
}
Expand Down Expand Up @@ -201,6 +202,25 @@ rename_thres <- function(bframe, pars, ...) {
out
}

# rename baseline hazard parameters in cox models
# replaces the numeric group index in the 'sbhaz' parameter names
# by the name of the corresponding baseline hazard group
# @param bframe object passed on to has_bhaz_groups, check_prefix
#   and get_bhaz_groups
# @param pars vector of parameter names to be searched
# @return a list with one 'rlist' entry per group (presumably appended
#   via 'lc<-'); an empty list if there are no baseline hazard groups
rename_bhaz <- function(bframe, pars, ...) {
out <- list()
# renaming is only required if group-specific baseline hazards were estimated
if (!has_bhaz_groups(bframe)) {
return(out)
}
px <- check_prefix(bframe)
p <- usc(combine_prefix(px))
groups <- get_bhaz_groups(bframe)
for (k in seq_along(groups)) {
# parameters of the k-th group currently start with 'sbhaz<p>[k,'
pos <- grepl(glue("^sbhaz{p}\\[{k},"), pars)
# running index over the matched parameters of this group
funs <- seq_len(sum(pos))
# new names carry the group name instead of the group index
bhaz_names <- glue("sbhaz{p}[{groups[k]},{funs}]")
lc(out) <- rlist(pos, bhaz_names)
}
out
}

# helps in renaming global noise free variables
# @param meframe data.frame returned by 'frame_me'
rename_Xme <- function(bframe, pars, ...) {
Expand Down
Loading

0 comments on commit 5bb6531

Please sign in to comment.