From d295d4e8e399127340a3707913d4f6dc0884e8ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Tue, 13 Aug 2024 11:19:54 +0200 Subject: [PATCH 01/21] fix issue #1633 --- DESCRIPTION | 2 +- R/stan-likelihood.R | 7 ++++--- inst/chunks/fun_von_mises.stan | 34 --------------------------------- tests/testthat/tests.stancode.R | 5 ----- 4 files changed, 5 insertions(+), 43 deletions(-) delete mode 100644 inst/chunks/fun_von_mises.stan diff --git a/DESCRIPTION b/DESCRIPTION index 7ee510496..e304a7d06 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -101,4 +101,4 @@ Additional_repositories: VignetteBuilder: knitr, R.rsp -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 diff --git a/R/stan-likelihood.R b/R/stan-likelihood.R index 458918bf6..adeeb2da0 100644 --- a/R/stan-likelihood.R +++ b/R/stan-likelihood.R @@ -714,6 +714,8 @@ stan_log_lik_wiener <- function(bterms, resp = "", mix = "", threads = NULL, } stan_log_lik_beta <- function(bterms, resp = "", mix = "", ...) { + # TODO: check if we still require n when phi is predicted + # and check the same for other families too reqn <- stan_log_lik_adj(bterms) || nzchar(mix) || paste0("phi", mix) %in% names(bterms$dpars) p <- stan_log_lik_dpars(bterms, reqn, resp, mix) @@ -724,10 +726,9 @@ stan_log_lik_beta <- function(bterms, resp = "", mix = "", ...) { } stan_log_lik_von_mises <- function(bterms, resp = "", mix = "", ...) 
{ - reqn <- stan_log_lik_adj(bterms) || nzchar(mix) || - "kappa" %in% names(bterms$dpars) + reqn <- stan_log_lik_adj(bterms) || nzchar(mix) p <- stan_log_lik_dpars(bterms, reqn, resp, mix) - sdist("von_mises2", p$mu, p$kappa) + sdist("von_mises", p$mu, p$kappa) } stan_log_lik_cox <- function(bterms, resp = "", mix = "", threads = NULL, diff --git a/inst/chunks/fun_von_mises.stan b/inst/chunks/fun_von_mises.stan deleted file mode 100644 index 2bdd080fd..000000000 --- a/inst/chunks/fun_von_mises.stan +++ /dev/null @@ -1,34 +0,0 @@ - /* von Mises log-PDF of a single response - * for kappa > 100 the normal approximation is used - * for reasons of numerial stability - * Args: - * y: the response vector between -pi and pi - * mu: location parameter vector - * kappa: precision parameter - * Returns: - * a scalar to be added to the log posterior - */ - real von_mises2_lpdf(real y, real mu, real kappa) { - if (kappa < 100) { - return von_mises_lpdf(y | mu, kappa); - } else { - return normal_lpdf(y | mu, sqrt(1 / kappa)); - } - } - /* von Mises log-PDF of a response vector - * for kappa > 100 the normal approximation is used - * for reasons of numerial stability - * Args: - * y: the response vector between -pi and pi - * mu: location parameter vector - * kappa: precision parameter - * Returns: - * a scalar to be added to the log posterior - */ - real von_mises2_lpdf(vector y, vector mu, real kappa) { - if (kappa < 100) { - return von_mises_lpdf(y | mu, kappa); - } else { - return normal_lpdf(y | mu, sqrt(1 / kappa)); - } - } diff --git a/tests/testthat/tests.stancode.R b/tests/testthat/tests.stancode.R index d4e1e0539..274c81238 100644 --- a/tests/testthat/tests.stancode.R +++ b/tests/testthat/tests.stancode.R @@ -484,11 +484,6 @@ test_that("self-defined functions appear in the Stan code", { expect_match2(scode, "real inv_gaussian_lccdf(real y") expect_match2(scode, "real inv_gaussian_lpdf(vector y") - # von Mises models - scode <- stancode(time ~ age, data = kidney, family 
= von_mises) - expect_match2(scode, "real von_mises2_lpdf(real y") - expect_match2(scode, "real von_mises2_lpdf(vector y") - # zero-inflated and hurdle models expect_match2(stancode(count ~ Trt, data = epilepsy, family = "zero_inflated_poisson"), From 6970f6adcfea7749cd783cbb2c80e29fe017e255 Mon Sep 17 00:00:00 2001 From: Ven Popov Date: Tue, 13 Aug 2024 13:23:25 +0200 Subject: [PATCH 02/21] remove outdated include statement --- R/family-lists.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/family-lists.R b/R/family-lists.R index 1ad1d6150..10c3df4d8 100644 --- a/R/family-lists.R +++ b/R/family-lists.R @@ -349,7 +349,7 @@ dpars = c("mu", "kappa"), type = "real", ybounds = c(-pi, pi), closed = c(TRUE, TRUE), ad = c("weights", "subset", "cens", "trunc", "mi", "index"), - include = c("fun_tan_half.stan", "fun_von_mises.stan"), + include = c("fun_tan_half.stan"), normalized = "", # experimental use of default priors stored in families #1614 prior = function(dpar, link = "identity", ...) { From 7f77cd9f02e604a1de276fa975d02895594c1976 Mon Sep 17 00:00:00 2001 From: Adam Howes Date: Tue, 27 Aug 2024 11:45:34 +0100 Subject: [PATCH 03/21] "observatio" typo --- R/log_lik.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/log_lik.R b/R/log_lik.R index ba416a457..dd4c45c48 100644 --- a/R/log_lik.R +++ b/R/log_lik.R @@ -149,7 +149,7 @@ log_lik_pointwise <- function(data_i, draws, ...) 
{ } # All log_lik_ functions have the same arguments structure -# @param i index of the observatio for which to compute log-lik values +# @param i index of the observation for which to compute log-lik values # @param prep A named list returned by prepare_predictions containing # all required data and posterior draws # @return a vector of length prep$ndraws containing the pointwise From 8cdec67d7e634e3a42706ec6e554ead53e952a30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Mon, 9 Sep 2024 15:18:08 +0200 Subject: [PATCH 04/21] fix issue #1685 --- R/brmsframe.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/brmsframe.R b/R/brmsframe.R index a3b64007c..0075cd764 100644 --- a/R/brmsframe.R +++ b/R/brmsframe.R @@ -34,7 +34,7 @@ brmsframe.brmsterms <- function(x, data, frame = NULL, basis = NULL, ...) { # this must be a multivariate model stopifnot(is.list(frame)) x$frame <- frame - x$frame$re <- subset(x$frame$re, resp = x$resp) + x$frame$re <- subset2(x$frame$re, resp = x$resp) } data <- subset_data(data, x) x$frame$resp <- frame_resp(x, data = data) From a3b38f3092c61bfc098ac573c881d58549c964ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 12 Sep 2024 11:07:58 +0200 Subject: [PATCH 05/21] fix issue #1634 --- R/brm.R | 1 + R/stan-predictor.R | 2 +- man/brm.Rd | 1 + man/brm_multiple.Rd | 1 + 4 files changed, 4 insertions(+), 1 deletion(-) diff --git a/R/brm.R b/R/brm.R index 360296318..fd7409adc 100644 --- a/R/brm.R +++ b/R/brm.R @@ -164,6 +164,7 @@ #' variational inference with independent normal distributions, #' \code{"fullrank"} for variational inference with a multivariate normal #' distribution, \code{"pathfinder"} for the pathfinder algorithm, +#' \code{"laplace"} for the Laplace approximation, #' or \code{"fixed_param"} for sampling from fixed parameter #' values. 
Can be set globally for the current \R session via the #' \code{"brms.algorithm"} option (see \code{\link{options}}). diff --git a/R/stan-predictor.R b/R/stan-predictor.R index 9968f72c1..78365993b 100644 --- a/R/stan-predictor.R +++ b/R/stan-predictor.R @@ -2039,7 +2039,7 @@ stan_eta_combine <- function(bframe, out, threads, primitive, ...) { out$loopeta <- NULL # some links need custom Stan functions link <- bframe$family$link - link_names <- c("cauchit", "cloglog", "softplus", "squareplus", "softit") + link_names <- c("cauchit", "cloglog", "softplus", "squareplus", "softit", "tan_half") needs_link_fun <- isTRUE(link %in% link_names) if (needs_link_fun) { str_add(out$fun) <- glue(" #include 'fun_{link}.stan'\n") diff --git a/man/brm.Rd b/man/brm.Rd index 86ff43222..2724e76f8 100644 --- a/man/brm.Rd +++ b/man/brm.Rd @@ -234,6 +234,7 @@ Options are \code{"sampling"} for MCMC (the default), \code{"meanfield"} for variational inference with independent normal distributions, \code{"fullrank"} for variational inference with a multivariate normal distribution, \code{"pathfinder"} for the pathfinder algorithm, +\code{"laplace"} for the Laplace approximation, or \code{"fixed_param"} for sampling from fixed parameter values. Can be set globally for the current \R session via the \code{"brms.algorithm"} option (see \code{\link{options}}).} diff --git a/man/brm_multiple.Rd b/man/brm_multiple.Rd index 938cc1391..e6070ae3c 100644 --- a/man/brm_multiple.Rd +++ b/man/brm_multiple.Rd @@ -143,6 +143,7 @@ Options are \code{"sampling"} for MCMC (the default), \code{"meanfield"} for variational inference with independent normal distributions, \code{"fullrank"} for variational inference with a multivariate normal distribution, \code{"pathfinder"} for the pathfinder algorithm, +\code{"laplace"} for the Laplace approximation, or \code{"fixed_param"} for sampling from fixed parameter values. 
Can be set globally for the current \R session via the \code{"brms.algorithm"} option (see \code{\link{options}}).} From ab26404e85ec5728d5fecefa82dbc9ccf1b22c2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 12 Sep 2024 11:37:48 +0200 Subject: [PATCH 06/21] fix issue #1672 --- R/priorsense.R | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/R/priorsense.R b/R/priorsense.R index 3353d52e9..019f2915f 100644 --- a/R/priorsense.R +++ b/R/priorsense.R @@ -59,6 +59,11 @@ log_lik_draws.brmsfit <- function(x) { #' @exportS3Method priorsense::log_prior_draws log_prior_draws.brmsfit <- function(x, log_prior_name = "lprior") { + stopifnot(length(log_prior_name) == 1) + if (!log_prior_name %in% variables(x)) { + warning2("Variable '", log_prior_name, "' was not found. ", + "Perhaps you used normalize = FALSE?") + } posterior::subset_draws( posterior::as_draws_array(x), variable = log_prior_name From c8d20003c6810f730d595336a00c808c22a37f33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 12 Sep 2024 12:46:31 +0200 Subject: [PATCH 07/21] fix issue #1629 --- R/prepare_predictions.R | 2 +- man/fitted.brmsfit.Rd | 2 +- man/log_lik.brmsfit.Rd | 2 +- man/posterior_epred.brmsfit.Rd | 2 +- man/posterior_linpred.brmsfit.Rd | 2 +- man/posterior_predict.brmsfit.Rd | 2 +- man/pp_mixture.brmsfit.Rd | 2 +- man/predict.brmsfit.Rd | 2 +- man/predictive_error.brmsfit.Rd | 2 +- man/prepare_predictions.Rd | 2 +- man/residuals.brmsfit.Rd | 2 +- man/standata.brmsfit.Rd | 2 +- man/validate_newdata.Rd | 2 +- 13 files changed, 13 insertions(+), 13 deletions(-) diff --git a/R/prepare_predictions.R b/R/prepare_predictions.R index 9012e66fe..90c15ef6f 100644 --- a/R/prepare_predictions.R +++ b/R/prepare_predictions.R @@ -1179,7 +1179,7 @@ is.bprepnl <- function(x) { #' predictions of the grand mean when using sum coding. #' @param re_formula formula containing group-level effects to be considered in #' the prediction. 
If \code{NULL} (default), include all group-level effects; -#' if \code{NA}, include no group-level effects. +#' if \code{NA} or \code{~0}, include no group-level effects. #' @param allow_new_levels A flag indicating if new levels of group-level #' effects are allowed (defaults to \code{FALSE}). Only relevant if #' \code{newdata} is provided. diff --git a/man/fitted.brmsfit.Rd b/man/fitted.brmsfit.Rd index e93a09ddb..730fbc4ae 100644 --- a/man/fitted.brmsfit.Rd +++ b/man/fitted.brmsfit.Rd @@ -32,7 +32,7 @@ predictions of the grand mean when using sum coding.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; -if \code{NA}, include no group-level effects.} +if \code{NA} or \code{~0}, include no group-level effects.} \item{scale}{Either \code{"response"} or \code{"linear"}. If \code{"response"}, results are returned on the scale diff --git a/man/log_lik.brmsfit.Rd b/man/log_lik.brmsfit.Rd index 64f43fe3f..d7a18d26a 100644 --- a/man/log_lik.brmsfit.Rd +++ b/man/log_lik.brmsfit.Rd @@ -31,7 +31,7 @@ predictions of the grand mean when using sum coding.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; -if \code{NA}, include no group-level effects.} +if \code{NA} or \code{~0}, include no group-level effects.} \item{resp}{Optional names of response variables. If specified, predictions are performed only for the specified response variables.} diff --git a/man/posterior_epred.brmsfit.Rd b/man/posterior_epred.brmsfit.Rd index d6bdc12b4..a8e57ebc4 100644 --- a/man/posterior_epred.brmsfit.Rd +++ b/man/posterior_epred.brmsfit.Rd @@ -31,7 +31,7 @@ predictions of the grand mean when using sum coding.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. 
If \code{NULL} (default), include all group-level effects; -if \code{NA}, include no group-level effects.} +if \code{NA} or \code{~0}, include no group-level effects.} \item{re.form}{Alias of \code{re_formula}.} diff --git a/man/posterior_linpred.brmsfit.Rd b/man/posterior_linpred.brmsfit.Rd index e4a57a53b..06f19879e 100644 --- a/man/posterior_linpred.brmsfit.Rd +++ b/man/posterior_linpred.brmsfit.Rd @@ -37,7 +37,7 @@ predictions of the grand mean when using sum coding.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; -if \code{NA}, include no group-level effects.} +if \code{NA} or \code{~0}, include no group-level effects.} \item{re.form}{Alias of \code{re_formula}.} diff --git a/man/posterior_predict.brmsfit.Rd b/man/posterior_predict.brmsfit.Rd index 5c67581fe..0a860ebb4 100644 --- a/man/posterior_predict.brmsfit.Rd +++ b/man/posterior_predict.brmsfit.Rd @@ -32,7 +32,7 @@ predictions of the grand mean when using sum coding.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; -if \code{NA}, include no group-level effects.} +if \code{NA} or \code{~0}, include no group-level effects.} \item{re.form}{Alias of \code{re_formula}.} diff --git a/man/pp_mixture.brmsfit.Rd b/man/pp_mixture.brmsfit.Rd index 446b7003d..82ff3f6a8 100644 --- a/man/pp_mixture.brmsfit.Rd +++ b/man/pp_mixture.brmsfit.Rd @@ -32,7 +32,7 @@ predictions of the grand mean when using sum coding.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; -if \code{NA}, include no group-level effects.} +if \code{NA} or \code{~0}, include no group-level effects.} \item{resp}{Optional names of response variables. 
If specified, predictions are performed only for the specified response variables.} diff --git a/man/predict.brmsfit.Rd b/man/predict.brmsfit.Rd index d2f731eea..c24e15411 100644 --- a/man/predict.brmsfit.Rd +++ b/man/predict.brmsfit.Rd @@ -33,7 +33,7 @@ predictions of the grand mean when using sum coding.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; -if \code{NA}, include no group-level effects.} +if \code{NA} or \code{~0}, include no group-level effects.} \item{transform}{(Deprecated) A function or a character string naming a function to be applied on the predicted responses diff --git a/man/predictive_error.brmsfit.Rd b/man/predictive_error.brmsfit.Rd index 02c8ce1ad..520dc36f1 100644 --- a/man/predictive_error.brmsfit.Rd +++ b/man/predictive_error.brmsfit.Rd @@ -29,7 +29,7 @@ predictions of the grand mean when using sum coding.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; -if \code{NA}, include no group-level effects.} +if \code{NA} or \code{~0}, include no group-level effects.} \item{re.form}{Alias of \code{re_formula}.} diff --git a/man/prepare_predictions.Rd b/man/prepare_predictions.Rd index 0ca06ab0a..ec75a1975 100644 --- a/man/prepare_predictions.Rd +++ b/man/prepare_predictions.Rd @@ -42,7 +42,7 @@ predictions of the grand mean when using sum coding.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; -if \code{NA}, include no group-level effects.} +if \code{NA} or \code{~0}, include no group-level effects.} \item{allow_new_levels}{A flag indicating if new levels of group-level effects are allowed (defaults to \code{FALSE}). 
Only relevant if diff --git a/man/residuals.brmsfit.Rd b/man/residuals.brmsfit.Rd index 23e8f450e..b124b81a5 100644 --- a/man/residuals.brmsfit.Rd +++ b/man/residuals.brmsfit.Rd @@ -31,7 +31,7 @@ predictions of the grand mean when using sum coding.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; -if \code{NA}, include no group-level effects.} +if \code{NA} or \code{~0}, include no group-level effects.} \item{method}{Method used to obtain predictions. Can be set to \code{"posterior_predict"} (the default), \code{"posterior_epred"}, diff --git a/man/standata.brmsfit.Rd b/man/standata.brmsfit.Rd index 0173c16ed..a5b1161c7 100644 --- a/man/standata.brmsfit.Rd +++ b/man/standata.brmsfit.Rd @@ -25,7 +25,7 @@ predictions of the grand mean when using sum coding.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; -if \code{NA}, include no group-level effects.} +if \code{NA} or \code{~0}, include no group-level effects.} \item{newdata2}{A named \code{list} of objects containing new data, which cannot be passed via argument \code{newdata}. Required for some objects diff --git a/man/validate_newdata.Rd b/man/validate_newdata.Rd index 85dd61b0e..05a18462c 100644 --- a/man/validate_newdata.Rd +++ b/man/validate_newdata.Rd @@ -25,7 +25,7 @@ validate_newdata( \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; -if \code{NA}, include no group-level effects.} +if \code{NA} or \code{~0}, include no group-level effects.} \item{allow_new_levels}{A flag indicating if new levels of group-level effects are allowed (defaults to \code{FALSE}). 
Only relevant if From ba11f9b1b17ab02434196bcd4d0782b967915ae0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 12 Sep 2024 12:50:52 +0200 Subject: [PATCH 08/21] fix issue #1644 --- R/stan-prior.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/stan-prior.R b/R/stan-prior.R index 7cd1b242c..4307d48e4 100644 --- a/R/stan-prior.R +++ b/R/stan-prior.R @@ -518,7 +518,7 @@ stan_unchecked_prior <- function(prior) { # @param sample_prior take draws from priors? stan_rngprior <- function(tpar_prior, par_declars, gen_quantities, special_prior, sample_prior = "yes") { - if (!is_equal(sample_prior, "yes")) { + if (!is_equal(sample_prior, "yes") || !length(tpar_prior)) { return(list()) } tpar_prior <- strsplit(gsub(" |\\n", "", tpar_prior), ";")[[1]] From ab16c36377729d77c94b2a3313f46f6357663fc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 12 Sep 2024 12:55:37 +0200 Subject: [PATCH 09/21] fix issue #1656 --- R/formula-ac.R | 6 ++++-- man/car.Rd | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/R/formula-ac.R b/R/formula-ac.R index 1c5dcac4a..29b1a38d4 100644 --- a/R/formula-ac.R +++ b/R/formula-ac.R @@ -333,6 +333,7 @@ sar <- function(M, type = "lag") { #' distance <- as.matrix(dist(Grid)) #' W <- array(0, c(K, K)) #' W[distance == 1] <- 1 +#' rownames(W) <- 1:nrow(W) #' #' # generate the covariates and response data #' x1 <- rnorm(K) @@ -345,10 +346,11 @@ sar <- function(M, type = "lag") { #' prob <- exp(eta) / (1 + exp(eta)) #' size <- rep(50, K) #' y <- rbinom(n = K, size = size, prob = prob) -#' dat <- data.frame(y, size, x1, x2) +#' g <- 1:length(y) +#' dat <- data.frame(y, size, x1, x2, g) #' #' # fit a CAR model -#' fit <- brm(y | trials(size) ~ x1 + x2 + car(W), +#' fit <- brm(y | trials(size) ~ x1 + x2 + car(W, gr = g), #' data = dat, data2 = list(W = W), #' family = binomial()) #' summary(fit) diff --git a/man/car.Rd b/man/car.Rd index 
d675ce7eb..03554ac17 100644 --- a/man/car.Rd +++ b/man/car.Rd @@ -49,6 +49,7 @@ K <- nrow(Grid) distance <- as.matrix(dist(Grid)) W <- array(0, c(K, K)) W[distance == 1] <- 1 +rownames(W) <- 1:nrow(W) # generate the covariates and response data x1 <- rnorm(K) @@ -61,10 +62,11 @@ eta <- x1 + x2 + phi prob <- exp(eta) / (1 + exp(eta)) size <- rep(50, K) y <- rbinom(n = K, size = size, prob = prob) -dat <- data.frame(y, size, x1, x2) +g <- 1:length(y) +dat <- data.frame(y, size, x1, x2, g) # fit a CAR model -fit <- brm(y | trials(size) ~ x1 + x2 + car(W), +fit <- brm(y | trials(size) ~ x1 + x2 + car(W, gr = g), data = dat, data2 = list(W = W), family = binomial()) summary(fit) From 34f2a0b23cd47c54f5b0cb398b73f059771919a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 12 Sep 2024 13:01:49 +0200 Subject: [PATCH 10/21] fix issue #1648 --- R/misc.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/misc.R b/R/misc.R index f0060b9f8..fd0e34d94 100644 --- a/R/misc.R +++ b/R/misc.R @@ -615,7 +615,8 @@ rename <- function(x, pattern = NULL, replacement = NULL, dup <- duplicated(out) if (check_dup && any(dup)) { dup <- x[out %in% out[dup]] - stop2("Internal renaming led to duplicated names. \n", + stop2("Internal renaming led to duplicated names. 
", + "Consider renaming your variables to have different suffixes.\n", "Occured for: ", collapse_comma(dup)) } out From 22550dcf4c72ae817ea08c89b47067b2db00ffe2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 12 Sep 2024 13:08:22 +0200 Subject: [PATCH 11/21] fix issue #1675 --- R/priors.R | 20 +++++++++++--------- man/R2D2.Rd | 13 ++++++++----- man/horseshoe.Rd | 7 +++---- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/R/priors.R b/R/priors.R index 1912a7fad..d6acd26a1 100644 --- a/R/priors.R +++ b/R/priors.R @@ -1973,10 +1973,9 @@ eval_dirichlet <- function(prior, len = NULL, env = NULL) { #' Regularized horseshoe priors in \pkg{brms} #' -#' Function used to set up regularized horseshoe priors and related -#' hierarchical shrinkage priors for population-level effects in \pkg{brms}. The -#' function does not evaluate its arguments -- it exists purely to help set up -#' the model. +#' Function used to set up regularized horseshoe priors and related hierarchical +#' shrinkage priors in \pkg{brms}. The function does not evaluate its arguments +#' -- it exists purely to help set up the model. #' #' @param df Degrees of freedom of student-t prior of the #' local shrinkage parameters. Defaults to \code{1}. @@ -2130,9 +2129,8 @@ horseshoe <- function(df = 1, scale_global = 1, df_global = 1, #' R2D2 Priors in \pkg{brms} #' -#' Function used to set up R2D2 priors for population-level effects in -#' \pkg{brms}. The function does not evaluate its arguments -- it exists purely -#' to help set up the model. +#' Function used to set up R2D2(M2) priors in \pkg{brms}. The function does +#' not evaluate its arguments -- it exists purely to help set up the model. #' #' @param mean_R2 Mean of the Beta prior on the coefficient of determination R^2. #' @param prec_R2 Precision of the Beta prior on the coefficient of determination R^2. 
@@ -2150,14 +2148,18 @@ horseshoe <- function(df = 1, scale_global = 1, df_global = 1, #' See the Examples section below. #' #' @details -#' Currently, the following classes support the R2D2 prior: \code{b} +#' Currently, the following classes support the R2D2(M2) prior: \code{b} #' (overall regression coefficients), \code{sds} (SDs of smoothing splines), #' \code{sdgp} (SDs of Gaussian processes), \code{ar} (autoregressive #' coefficients), \code{ma} (moving average coefficients), \code{sderr} (SD of #' latent residuals), \code{sdcar} (SD of spatial CAR structures), \code{sd} #' (SD of varying coefficients). #' -#' Even when the R2D2 prior is applied to multiple parameter classes at once, +#' When the prior is only applied to parameter class \code{b}, it is equivalent +#' to the original R2D2 prior (with Gaussian kernel). When the prior is also +#' applied to other parameter classes, it is equivalent to the R2D2M2 prior. +#' +#' Even when the R2D2(M2) prior is applied to multiple parameter classes at once, #' the concentration vector (argument \code{cons_D2}) has to be provided #' jointly in the the one instance of the prior where \code{main = TRUE}. The #' order in which the elements of concentration vector correspond to the diff --git a/man/R2D2.Rd b/man/R2D2.Rd index 95e73f8f4..408f7c393 100644 --- a/man/R2D2.Rd +++ b/man/R2D2.Rd @@ -27,19 +27,22 @@ Arguments given in other instances of the prior will be ignored. See the Examples section below.} } \description{ -Function used to set up R2D2 priors for population-level effects in -\pkg{brms}. The function does not evaluate its arguments -- it exists purely -to help set up the model. +Function used to set up R2D2(M2) priors in \pkg{brms}. The function does +not evaluate its arguments -- it exists purely to help set up the model. 
} \details{ -Currently, the following classes support the R2D2 prior: \code{b} +Currently, the following classes support the R2D2(M2) prior: \code{b} (overall regression coefficients), \code{sds} (SDs of smoothing splines), \code{sdgp} (SDs of Gaussian processes), \code{ar} (autoregressive coefficients), \code{ma} (moving average coefficients), \code{sderr} (SD of latent residuals), \code{sdcar} (SD of spatial CAR structures), \code{sd} (SD of varying coefficients). - Even when the R2D2 prior is applied to multiple parameter classes at once, + When the prior is only applied to parameter class \code{b}, it is equivalent + to the original R2D2 prior (with Gaussian kernel). When the prior is also + applied to other parameter classes, it is equivalent to the R2D2M2 prior. + + Even when the R2D2(M2) prior is applied to multiple parameter classes at once, the concentration vector (argument \code{cons_D2}) has to be provided jointly in the the one instance of the prior where \code{main = TRUE}. The order in which the elements of concentration vector correspond to the diff --git a/man/horseshoe.Rd b/man/horseshoe.Rd index d9ba6c91d..035557ec1 100644 --- a/man/horseshoe.Rd +++ b/man/horseshoe.Rd @@ -61,10 +61,9 @@ A character string obtained by \code{match.call()} with additional arguments. } \description{ -Function used to set up regularized horseshoe priors and related -hierarchical shrinkage priors for population-level effects in \pkg{brms}. The -function does not evaluate its arguments -- it exists purely to help set up -the model. +Function used to set up regularized horseshoe priors and related hierarchical +shrinkage priors in \pkg{brms}. The function does not evaluate its arguments +-- it exists purely to help set up the model. 
} \details{ The horseshoe prior is a special shrinkage prior initially proposed by From 96902a7eaba0106169bfa7bf4c5b461932ef3440 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 12 Sep 2024 13:14:00 +0200 Subject: [PATCH 12/21] fix issue #1655 --- R/priors.R | 8 ++++++++ man/R2D2.Rd | 6 +++++- man/horseshoe.Rd | 4 ++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/R/priors.R b/R/priors.R index d6acd26a1..83d7fad62 100644 --- a/R/priors.R +++ b/R/priors.R @@ -2051,6 +2051,10 @@ eval_dirichlet <- function(prior, len = NULL, env = NULL) { #' See the documentation of \code{\link{brm}} for instructions #' on how to increase \code{adapt_delta}. #' +#' The prior does not account for scale differences of the terms it is +#' applied to. Accordingly, please make sure that all these terms have a +#' comparable scale to ensure that shrinkage is applied properly. +#' #' Currently, the following classes support the horseshoe prior: \code{b} #' (overall regression coefficients), \code{sds} (SDs of smoothing splines), #' \code{sdgp} (SDs of Gaussian processes), \code{ar} (autoregressive @@ -2148,6 +2152,10 @@ horseshoe <- function(df = 1, scale_global = 1, df_global = 1, #' See the Examples section below. #' #' @details +#' The prior does not account for scale differences of the terms it is +#' applied to. Accordingly, please make sure that all these terms have a +#' comparable scale to ensure that shrinkage is applied properly. +#' #' Currently, the following classes support the R2D2(M2) prior: \code{b} #' (overall regression coefficients), \code{sds} (SDs of smoothing splines), #' \code{sdgp} (SDs of Gaussian processes), \code{ar} (autoregressive diff --git a/man/R2D2.Rd b/man/R2D2.Rd index 408f7c393..5f03c9147 100644 --- a/man/R2D2.Rd +++ b/man/R2D2.Rd @@ -31,7 +31,11 @@ Function used to set up R2D2(M2) priors in \pkg{brms}. The function does not evaluate its arguments -- it exists purely to help set up the model. 
} \details{ -Currently, the following classes support the R2D2(M2) prior: \code{b} +The prior does not account for scale differences of the terms it is + applied to. Accordingly, please make sure that all these terms have a + comparable scale to ensure that shrinkage is applied properly. + + Currently, the following classes support the R2D2(M2) prior: \code{b} (overall regression coefficients), \code{sds} (SDs of smoothing splines), \code{sdgp} (SDs of Gaussian processes), \code{ar} (autoregressive coefficients), \code{ma} (moving average coefficients), \code{sderr} (SD of diff --git a/man/horseshoe.Rd b/man/horseshoe.Rd index 035557ec1..5e9c83138 100644 --- a/man/horseshoe.Rd +++ b/man/horseshoe.Rd @@ -103,6 +103,10 @@ The horseshoe prior is a special shrinkage prior initially proposed by See the documentation of \code{\link{brm}} for instructions on how to increase \code{adapt_delta}. + The prior does not account for scale differences of the terms it is + applied to. Accordingly, please make sure that all these terms have a + comparable scale to ensure that shrinkage is applied properly. 
+ Currently, the following classes support the horseshoe prior: \code{b} (overall regression coefficients), \code{sds} (SDs of smoothing splines), \code{sdgp} (SDs of Gaussian processes), \code{ar} (autoregressive From 407cad84ee9f5dcbf921cb9794be4772be723e9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 12 Sep 2024 16:31:21 +0200 Subject: [PATCH 13/21] fix issue #1652 --- R/data-helpers.R | 4 ++++ R/prepare_predictions.R | 9 +++++---- man/fitted.brmsfit.Rd | 9 +++++---- man/get_refmodel.brmsfit.Rd | 9 +++++---- man/log_lik.brmsfit.Rd | 9 +++++---- man/loo_moment_match.brmsfit.Rd | 9 +++++---- man/posterior_epred.brmsfit.Rd | 9 +++++---- man/posterior_linpred.brmsfit.Rd | 9 +++++---- man/posterior_predict.brmsfit.Rd | 9 +++++---- man/pp_check.brmsfit.Rd | 9 +++++---- man/pp_mixture.brmsfit.Rd | 9 +++++---- man/predict.brmsfit.Rd | 9 +++++---- man/predictive_error.brmsfit.Rd | 9 +++++---- man/prepare_predictions.Rd | 9 +++++---- man/psis.brmsfit.Rd | 9 +++++---- man/reloo.brmsfit.Rd | 9 +++++---- man/residuals.brmsfit.Rd | 9 +++++---- man/standata.brmsfit.Rd | 9 +++++---- 18 files changed, 89 insertions(+), 68 deletions(-) diff --git a/R/data-helpers.R b/R/data-helpers.R index b8a5eca7f..d63d9a81e 100644 --- a/R/data-helpers.R +++ b/R/data-helpers.R @@ -558,6 +558,10 @@ validate_newdata <- function( new_levels <- get_levels(bterms, data = newdata) for (g in names(old_levels)) { unknown_levels <- setdiff(new_levels[[g]], old_levels[[g]]) + # NA is not found by get_levels but still behaves like a new level (#1652) + if (anyNA(newdata[[g]])) { + c(unknown_levels) <- NA + } if (length(unknown_levels)) { unknown_levels <- collapse_comma(unknown_levels) stop2( diff --git a/R/prepare_predictions.R b/R/prepare_predictions.R index 90c15ef6f..508d81725 100644 --- a/R/prepare_predictions.R +++ b/R/prepare_predictions.R @@ -1173,10 +1173,11 @@ is.bprepnl <- function(x) { #' #' @param x An \R object typically of class \code{'brmsfit'}. 
#' @param newdata An optional data.frame for which to evaluate predictions. If -#' \code{NULL} (default), the original data of the model is used. -#' \code{NA} values within factors are interpreted as if all dummy -#' variables of this factor are zero. This allows, for instance, to make -#' predictions of the grand mean when using sum coding. +#' \code{NULL} (default), the original data of the model is used. \code{NA} +#' values within factors (excluding grouping variables) are interpreted as if +#' all dummy variables of this factor are zero. This allows, for instance, to +#' make predictions of the grand mean when using sum coding. \code{NA} values +#' within grouping variables are treated as a new level. #' @param re_formula formula containing group-level effects to be considered in #' the prediction. If \code{NULL} (default), include all group-level effects; #' if \code{NA} or \code{~0}, include no group-level effects. diff --git a/man/fitted.brmsfit.Rd b/man/fitted.brmsfit.Rd index 730fbc4ae..9811eef2e 100644 --- a/man/fitted.brmsfit.Rd +++ b/man/fitted.brmsfit.Rd @@ -25,10 +25,11 @@ \item{object}{An object of class \code{brmsfit}.} \item{newdata}{An optional data.frame for which to evaluate predictions. If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. \code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. This allows, for instance, to +make predictions of the grand mean when using sum coding. \code{NA} values +within grouping variables are treated as a new level.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. 
If \code{NULL} (default), include all group-level effects; diff --git a/man/get_refmodel.brmsfit.Rd b/man/get_refmodel.brmsfit.Rd index 9a6050272..60ec62f2c 100644 --- a/man/get_refmodel.brmsfit.Rd +++ b/man/get_refmodel.brmsfit.Rd @@ -19,10 +19,11 @@ get_refmodel.brmsfit( \item{object}{An object of class \code{brmsfit}.} \item{newdata}{An optional data.frame for which to evaluate predictions. If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. \code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. This allows, for instance, to +make predictions of the grand mean when using sum coding. \code{NA} values +within grouping variables are treated as a new level.} \item{resp}{Optional names of response variables. If specified, predictions are performed only for the specified response variables.} diff --git a/man/log_lik.brmsfit.Rd b/man/log_lik.brmsfit.Rd index d7a18d26a..67dafa487 100644 --- a/man/log_lik.brmsfit.Rd +++ b/man/log_lik.brmsfit.Rd @@ -24,10 +24,11 @@ \item{object}{A fitted model object of class \code{brmsfit}.} \item{newdata}{An optional data.frame for which to evaluate predictions. If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. \code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. This allows, for instance, to +make predictions of the grand mean when using sum coding. 
\code{NA} values +within grouping variables are treated as a new level.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; diff --git a/man/loo_moment_match.brmsfit.Rd b/man/loo_moment_match.brmsfit.Rd index fd0f45595..e6e6a418e 100644 --- a/man/loo_moment_match.brmsfit.Rd +++ b/man/loo_moment_match.brmsfit.Rd @@ -33,10 +33,11 @@ See \code{\link[loo:pareto-k-diagnostic]{pareto_k_ids}} for more details.} \item{newdata}{An optional data.frame for which to evaluate predictions. If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. \code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. This allows, for instance, to +make predictions of the grand mean when using sum coding. \code{NA} values +within grouping variables are treated as a new level.} \item{resp}{Optional names of response variables. If specified, predictions are performed only for the specified response variables.} diff --git a/man/posterior_epred.brmsfit.Rd b/man/posterior_epred.brmsfit.Rd index a8e57ebc4..75cc93724 100644 --- a/man/posterior_epred.brmsfit.Rd +++ b/man/posterior_epred.brmsfit.Rd @@ -24,10 +24,11 @@ \item{object}{An object of class \code{brmsfit}.} \item{newdata}{An optional data.frame for which to evaluate predictions. If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. 
\code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. This allows, for instance, to +make predictions of the grand mean when using sum coding. \code{NA} values +within grouping variables are treated as a new level.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; diff --git a/man/posterior_linpred.brmsfit.Rd b/man/posterior_linpred.brmsfit.Rd index 06f19879e..683fa82ff 100644 --- a/man/posterior_linpred.brmsfit.Rd +++ b/man/posterior_linpred.brmsfit.Rd @@ -30,10 +30,11 @@ If \code{TRUE}, draws of the transformed linear predictor, that is, after applying the inverse link function are returned.} \item{newdata}{An optional data.frame for which to evaluate predictions. If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. \code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. This allows, for instance, to +make predictions of the grand mean when using sum coding. \code{NA} values +within grouping variables are treated as a new level.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; diff --git a/man/posterior_predict.brmsfit.Rd b/man/posterior_predict.brmsfit.Rd index 0a860ebb4..3e3f89740 100644 --- a/man/posterior_predict.brmsfit.Rd +++ b/man/posterior_predict.brmsfit.Rd @@ -25,10 +25,11 @@ \item{object}{An object of class \code{brmsfit}.} \item{newdata}{An optional data.frame for which to evaluate predictions. 
If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. \code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. This allows, for instance, to +make predictions of the grand mean when using sum coding. \code{NA} values +within grouping variables are treated as a new level.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; diff --git a/man/pp_check.brmsfit.Rd b/man/pp_check.brmsfit.Rd index fadb2f1d2..737333f86 100644 --- a/man/pp_check.brmsfit.Rd +++ b/man/pp_check.brmsfit.Rd @@ -49,10 +49,11 @@ Only used for ppc types having an \code{x} argument and ignored otherwise.} \item{newdata}{An optional data.frame for which to evaluate predictions. If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. \code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. This allows, for instance, to +make predictions of the grand mean when using sum coding. \code{NA} values +within grouping variables are treated as a new level.} \item{resp}{Optional names of response variables. 
If specified, predictions are performed only for the specified response variables.} diff --git a/man/pp_mixture.brmsfit.Rd b/man/pp_mixture.brmsfit.Rd index 82ff3f6a8..e7be94760 100644 --- a/man/pp_mixture.brmsfit.Rd +++ b/man/pp_mixture.brmsfit.Rd @@ -25,10 +25,11 @@ pp_mixture(x, ...) \item{x}{An \R object usually of class \code{brmsfit}.} \item{newdata}{An optional data.frame for which to evaluate predictions. If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. \code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. This allows, for instance, to +make predictions of the grand mean when using sum coding. \code{NA} values +within grouping variables are treated as a new level.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; diff --git a/man/predict.brmsfit.Rd b/man/predict.brmsfit.Rd index c24e15411..a197c395f 100644 --- a/man/predict.brmsfit.Rd +++ b/man/predict.brmsfit.Rd @@ -26,10 +26,11 @@ \item{object}{An object of class \code{brmsfit}.} \item{newdata}{An optional data.frame for which to evaluate predictions. If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. \code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. 
This allows, for instance, to +make predictions of the grand mean when using sum coding. \code{NA} values +within grouping variables are treated as a new level.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; diff --git a/man/predictive_error.brmsfit.Rd b/man/predictive_error.brmsfit.Rd index 520dc36f1..6e3485630 100644 --- a/man/predictive_error.brmsfit.Rd +++ b/man/predictive_error.brmsfit.Rd @@ -22,10 +22,11 @@ \item{object}{An object of class \code{brmsfit}.} \item{newdata}{An optional data.frame for which to evaluate predictions. If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. \code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. This allows, for instance, to +make predictions of the grand mean when using sum coding. \code{NA} values +within grouping variables are treated as a new level.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; diff --git a/man/prepare_predictions.Rd b/man/prepare_predictions.Rd index ec75a1975..4062895fd 100644 --- a/man/prepare_predictions.Rd +++ b/man/prepare_predictions.Rd @@ -35,10 +35,11 @@ prepare_predictions(x, ...) \item{x}{An \R object typically of class \code{'brmsfit'}.} \item{newdata}{An optional data.frame for which to evaluate predictions. If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. 
This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. \code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. This allows, for instance, to +make predictions of the grand mean when using sum coding. \code{NA} values +within grouping variables are treated as a new level.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; diff --git a/man/psis.brmsfit.Rd b/man/psis.brmsfit.Rd index d0f0b90f7..2af401ca9 100644 --- a/man/psis.brmsfit.Rd +++ b/man/psis.brmsfit.Rd @@ -13,10 +13,11 @@ Argument is named "log_ratios" to match the argument name of the \code{\link[loo:psis]{loo::psis}} generic function.} \item{newdata}{An optional data.frame for which to evaluate predictions. If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. \code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. This allows, for instance, to +make predictions of the grand mean when using sum coding. \code{NA} values +within grouping variables are treated as a new level.} \item{resp}{Optional names of response variables. 
If specified, predictions are performed only for the specified response variables.} diff --git a/man/reloo.brmsfit.Rd b/man/reloo.brmsfit.Rd index 8a51e82f5..0e9100f32 100644 --- a/man/reloo.brmsfit.Rd +++ b/man/reloo.brmsfit.Rd @@ -36,10 +36,11 @@ See \code{\link[loo:pareto-k-diagnostic]{pareto_k_ids}} for more details.} \item{newdata}{An optional data.frame for which to evaluate predictions. If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. \code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. This allows, for instance, to +make predictions of the grand mean when using sum coding. \code{NA} values +within grouping variables are treated as a new level.} \item{resp}{Optional names of response variables. If specified, predictions are performed only for the specified response variables.} diff --git a/man/residuals.brmsfit.Rd b/man/residuals.brmsfit.Rd index b124b81a5..30bcca1bf 100644 --- a/man/residuals.brmsfit.Rd +++ b/man/residuals.brmsfit.Rd @@ -24,10 +24,11 @@ \item{object}{An object of class \code{brmsfit}.} \item{newdata}{An optional data.frame for which to evaluate predictions. If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. \code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. This allows, for instance, to +make predictions of the grand mean when using sum coding. 
\code{NA} values +within grouping variables are treated as a new level.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; diff --git a/man/standata.brmsfit.Rd b/man/standata.brmsfit.Rd index a5b1161c7..4af4dc9f1 100644 --- a/man/standata.brmsfit.Rd +++ b/man/standata.brmsfit.Rd @@ -18,10 +18,11 @@ \item{object}{An object of class \code{brmsfit}.} \item{newdata}{An optional data.frame for which to evaluate predictions. If -\code{NULL} (default), the original data of the model is used. -\code{NA} values within factors are interpreted as if all dummy -variables of this factor are zero. This allows, for instance, to make -predictions of the grand mean when using sum coding.} +\code{NULL} (default), the original data of the model is used. \code{NA} +values within factors (excluding grouping variables) are interpreted as if +all dummy variables of this factor are zero. This allows, for instance, to +make predictions of the grand mean when using sum coding. \code{NA} values +within grouping variables are treated as a new level.} \item{re_formula}{formula containing group-level effects to be considered in the prediction. If \code{NULL} (default), include all group-level effects; From 39a6feea9897c1d3a23b78bb1d868429e6cd499a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 12 Sep 2024 16:34:02 +0200 Subject: [PATCH 14/21] Fix issue #1668 --- vignettes/brms_multivariate.Rmd | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vignettes/brms_multivariate.Rmd b/vignettes/brms_multivariate.Rmd index ca7cd3fee..4ee6fd09a 100644 --- a/vignettes/brms_multivariate.Rmd +++ b/vignettes/brms_multivariate.Rmd @@ -75,8 +75,8 @@ summary(fit1) The summary output of multivariate models closely resembles those of univariate models, except that the parameters now have the corresponding response variable -as prefix. 
Within dams, tarsus length and back color seem to be negatively -correlated, while within fosternests the opposite is true. This indicates +as prefix. Across dams, tarsus length and back color seem to be negatively +correlated, while across fosternests the opposite is true. This indicates differential effects of genetic and environmental factors on these two characteristics. Further, the small residual correlation `rescor(tarsus, back)` on the bottom of the output indicates that there is little unmodeled dependency From a04ac6824c62c5fc81c8f845a77e4380cbc75fa6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 12 Sep 2024 17:12:43 +0200 Subject: [PATCH 15/21] feature issue #1635 --- DESCRIPTION | 4 ++-- R/numeric-helpers.R | 23 +++++++++++++++++++++++ R/prepare_predictions.R | 2 +- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index e304a7d06..ba96b4de4 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -2,8 +2,8 @@ Package: brms Encoding: UTF-8 Type: Package Title: Bayesian Regression Models using 'Stan' -Version: 2.21.6 -Date: 2024-06-06 +Version: 2.21.7 +Date: 2024-09-12 Authors@R: c(person("Paul-Christian", "Bürkner", email = "paul.buerkner@gmail.com", role = c("aut", "cre")), diff --git a/R/numeric-helpers.R b/R/numeric-helpers.R index 26d920948..8a5b60431 100644 --- a/R/numeric-helpers.R +++ b/R/numeric-helpers.R @@ -216,3 +216,26 @@ log1m_inv_softit <- function(x) { y <- log1p_exp(x) -log1p(y) } + +# names of built-in stan functions reimplemented in R within brms +names_stan_functions <- function() { + c("logit", "inv_logit", "cloglog", "inv_cloglog", "Phi", "incgamma", + "square", "cbrt", "exp2", "pow", "inv", "inv_sqrt", "inv_square", + "hypot", "log1m", "step", "logm1", "expp1", "logit_scaled", + "inv_logit_scaled", "multiply_log", "log1p_exp", "log1m_exp", + "log_diff_exp", "log_sum_exp", "log_mean_exp", "log_expm1", + "log_inv_logit", "log1m_inv_logit", "scale_unit", "fabs",
"log_softmax", + "softmax", "inv_odds", "softit", "inv_softit", "log_inv_softit", + "log1m_inv_softit") +} + +# create an environment with all the reimplemented stan functions in it +# see issue #1635 for discussion of this approach +env_stan_functions <- function(...) { + env <- new.env(...) + brms_env <- asNamespace("brms") + for (f in names_stan_functions()) { + env[[f]] <- get(f, brms_env) + } + env +} diff --git a/R/prepare_predictions.R b/R/prepare_predictions.R index 508d81725..0d59eb993 100644 --- a/R/prepare_predictions.R +++ b/R/prepare_predictions.R @@ -196,7 +196,7 @@ prepare_predictions.bframenl <- function(x, draws, sdata, ...) { out <- list( family = x$family, nlform = x$formula[[2]], - env = environment(x$formula), + env = env_stan_functions(parent = environment(x$formula)), ndraws = nrow(draws), nobs = sdata[[paste0("N", usc(x$resp))]], used_nlpars = x$used_nlpars, From 87feb9d251a27f56884fb29132e7629659cd53b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 12 Sep 2024 17:18:04 +0200 Subject: [PATCH 16/21] fix issue #1665 --- R/loo.R | 9 +++++---- man/loo.brmsfit.Rd | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/R/loo.R b/R/loo.R index 3789da20c..88a96c21f 100644 --- a/R/loo.R +++ b/R/loo.R @@ -34,10 +34,11 @@ #' See \code{\link[loo:pareto-k-diagnostic]{pareto_k_ids}} for more details. #' @param save_psis Should the \code{"psis"} object created internally be saved #' in the returned object? For more details see \code{\link[loo:loo]{loo}}. -#' @param moment_match_args Optional \code{list} of additional arguments passed to -#' \code{\link{loo_moment_match}}. -#' @param reloo_args Optional \code{list} of additional arguments passed to -#' \code{\link{reloo}}. +#' @param moment_match_args Optional named \code{list} of additional arguments +#' passed to \code{\link{loo_moment_match}}.
+#' @param reloo_args Optional named \code{list} of additional arguments passed to +#' \code{\link{reloo}}. This can be useful, among others, to control +#' how many chains, iterations, etc. to use for the fitted sub-models. #' @param model_names If \code{NULL} (the default) will use model names #' derived from deparsing the call. Otherwise will use the passed #' values as model names. diff --git a/man/loo.brmsfit.Rd b/man/loo.brmsfit.Rd index 11ead8904..813b11d20 100644 --- a/man/loo.brmsfit.Rd +++ b/man/loo.brmsfit.Rd @@ -62,11 +62,12 @@ See \code{\link[loo:pareto-k-diagnostic]{pareto_k_ids}} for more details.} \item{save_psis}{Should the \code{"psis"} object created internally be saved in the returned object? For more details see \code{\link[loo:loo]{loo}}.} -\item{moment_match_args}{Optional \code{list} of additional arguments passed to -\code{\link{loo_moment_match}}.} +\item{moment_match_args}{Optional named \code{list} of additional arguments +passed to \code{\link{loo_moment_match}}.} -\item{reloo_args}{Optional \code{list} of additional arguments passed to -\code{\link{reloo}}.} +\item{reloo_args}{Optional named \code{list} of additional arguments passed to +\code{\link{reloo}}. This can be useful, among others, to control +how many chains, iterations, etc. to use for the fitted sub-models.} \item{model_names}{If \code{NULL} (the default) will use model names derived from deparsing the call. 
Otherwise will use the passed From 700f7d7d9a3f31aad9a3f1ff3e21a8e38e318ccc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Fri, 13 Sep 2024 18:46:04 +0200 Subject: [PATCH 17/21] fix failing tests --- tests/testthat/tests.misc.R | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/testthat/tests.misc.R b/tests/testthat/tests.misc.R index e17c1dafe..e684e179b 100644 --- a/tests/testthat/tests.misc.R +++ b/tests/testthat/tests.misc.R @@ -18,14 +18,11 @@ test_that("rmNULL removes all NULL entries", { test_that("rename returns an error on duplicated names", { expect_error(rename(c(letters[1:4],"a()","a["), check_dup = TRUE), fixed = TRUE, - paste("Internal renaming led to duplicated names.", - "\nOccured for: 'a', 'a()', 'a['")) + paste("Occured for: 'a', 'a()', 'a['")) expect_error(rename(c("aDb","a/b","b"), check_dup = TRUE), fixed = TRUE, - paste("Internal renaming led to duplicated names.", - "\nOccured for: 'aDb', 'a/b'")) + paste("Occured for: 'aDb', 'a/b'")) expect_error(rename(c("log(a,b)","logab","bac","ba"), check_dup = TRUE), fixed = TRUE, - paste("Internal renaming led to duplicated names.", - "\nOccured for: 'log(a,b)', 'logab'")) + paste("Occured for: 'log(a,b)', 'logab'")) }) test_that("rename perform correct renaming", { From f6632d70f95cbf8dbdc072a89e28beadc9d6ec6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Fri, 13 Sep 2024 19:53:23 +0200 Subject: [PATCH 18/21] feature issue #1684 --- DESCRIPTION | 2 +- NEWS.md | 1 + R/backends.R | 50 +++++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 43 insertions(+), 10 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index ba96b4de4..2819b8f87 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -2,7 +2,7 @@ Package: brms Encoding: UTF-8 Type: Package Title: Bayesian Regression Models using 'Stan' -Version: 2.21.7 +Version: 2.21.8 Date: 2024-09-12 Authors@R: c(person("Paul-Christian", "Bürkner", email = 
"paul.buerkner@gmail.com", diff --git a/NEWS.md b/NEWS.md index 9b2af712b..027c9b35c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,7 @@ ### New Features +* Support futures for parallelization in the `cmdstanr` backend. (#1684) * Add method `loo_epred` thanks to Aki Vehtari. (#1641) * Add priorsense support via `create_priorsense_data.brmsfit` thanks to Noa Kallioinen. (#1354) diff --git a/R/backends.R b/R/backends.R index 6d83f3017..3000383f6 100644 --- a/R/backends.R +++ b/R/backends.R @@ -166,6 +166,7 @@ fit_model <- function(model, backend, ...) { } else if (is.character(init) && !init %in% c("random", "0")) { init <- get(init, mode = "function", envir = parent.frame()) } + future <- future && algorithm %in% "sampling" args <- nlist( object = model, data = sdata, iter, seed, init = init, pars = exclude, include = FALSE @@ -187,7 +188,7 @@ fit_model <- function(model, backend, ...) { warning2("Argument 'cores' is ignored when using 'future'.") } args$chains <- 1L - futures <- fits <- vector("list", chains) + out <- futures <- vector("list", chains) for (i in seq_len(chains)) { args$chain_id <- i if (is.list(init)) { @@ -200,10 +201,10 @@ fit_model <- function(model, backend, ...) { ) } for (i in seq_len(chains)) { - fits[[i]] <- future::value(futures[[i]]) + out[[i]] <- future::value(futures[[i]]) } - out <- rstan::sflist2stanfit(fits) - rm(futures, fits) + out <- rstan::sflist2stanfit(out) + rm(futures) } else { c(args) <- nlist(chains, cores) out <- do_call(rstan::sampling, args) @@ -239,9 +240,7 @@ fit_model <- function(model, backend, ...) { } else if (is_equal(init, "0")) { init <- 0 } - if (future) { - stop2("Argument 'future' is not supported by backend 'cmdstanr'.") - } + future <- future && algorithm %in% "sampling" args <- nlist(data = sdata, seed, init) if (use_opencl(opencl)) { args$opencl_ids <- opencl$ids @@ -279,7 +278,30 @@ fit_model <- function(model, backend, ...) 
{ if (use_threading(threads)) { args$threads_per_chain <- threads$threads } - out <- do_call(model$sample, args) + if (future) { + if (cores > 1L) { + warning2("Argument 'cores' is ignored when using 'future'.") + } + args$chains <- 1L + out <- futures <- vector("list", chains) + for (i in seq_len(chains)) { + args$chain_ids <- i + if (is.list(init)) { + args$init <- init[i] + } + futures[[i]] <- future::future( + brms::do_call(model$sample, args), + packages = "cmdstanr", + seed = TRUE + ) + } + for (i in seq_len(chains)) { + out[[i]] <- future::value(futures[[i]]) + } + rm(futures) + } else { + out <- do_call(model$sample, args) + } } else if (algorithm %in% c("fullrank", "meanfield")) { c(args) <- nlist(iter, algorithm) if (use_threading(threads)) { @@ -300,8 +322,18 @@ fit_model <- function(model, backend, ...) { stop2("Algorithm '", algorithm, "' is not supported.") } + if (future) { + # 'out' is a list of fitted models + output_files <- ulapply(out, function(x) x$output_files()) + stan_variables <- out[[1]]$metadata()$stan_variables + } else { + # 'out' is a single fitted model + output_files <- out$output_files() + stan_variables <- out$metadata()$stan_variables + } + out <- read_csv_as_stanfit( - out$output_files(), variables = out$metadata()$stan_variables, + output_files, variables = stan_variables, model = model, exclude = exclude, algorithm = algorithm ) From 496a6d8c4b24ddce8301b7d253689b459291ee77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Fri, 13 Sep 2024 20:38:37 +0200 Subject: [PATCH 19/21] feature issue #1674 --- R/loo_predict.R | 13 ++++++++---- R/pp_check.R | 44 +++++++++++++++++++---------------------- man/pp_check.brmsfit.Rd | 2 +- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/R/loo_predict.R b/R/loo_predict.R index ac267b723..03dbbbd07 100644 --- a/R/loo_predict.R +++ b/R/loo_predict.R @@ -64,7 +64,9 @@ loo_predict.brmsfit <- function(object, type = c("mean", "var", "quantile"), type <- 
match.arg(type) if (is.null(psis_object)) { message("Running PSIS to compute weights") - psis_object <- compute_loo(object, criterion = "psis", resp = resp, ...) + # run loo instead of psis to allow for moment matching + loo_object <- loo(object, resp = resp, save_psis = TRUE, ...) + psis_object <- loo_object$psis_object } preds <- posterior_predict(object, resp = resp, ...) E_loo_value(preds, psis_object, type = type, probs = probs) @@ -79,10 +81,11 @@ loo_epred.brmsfit <- function(object, type = c("mean", "var", "quantile"), probs = 0.5, psis_object = NULL, resp = NULL, ...) { type <- match.arg(type) - # stopifnot_resp(object, resp) if (is.null(psis_object)) { message("Running PSIS to compute weights") - psis_object <- compute_loo(object, criterion = "psis", resp = resp, ...) + # run loo instead of psis to allow for moment matching + loo_object <- loo(object, resp = resp, save_psis = TRUE, ...) + psis_object <- loo_object$psis_object } preds <- posterior_epred(object, resp = resp, ...) E_loo_value(preds, psis_object, type = type, probs = probs) @@ -106,7 +109,9 @@ loo_linpred.brmsfit <- function(object, type = c("mean", "var", "quantile"), type <- match.arg(type) if (is.null(psis_object)) { message("Running PSIS to compute weights") - psis_object <- compute_loo(object, criterion = "psis", resp = resp, ...) + # run loo instead of psis to allow for moment matching + loo_object <- loo(object, resp = resp, save_psis = TRUE, ...) + psis_object <- loo_object$psis_object } preds <- posterior_linpred(object, resp = resp, ...) E_loo_value(preds, psis_object, type = type, probs = probs) diff --git a/R/pp_check.R b/R/pp_check.R index 096257dce..13e9d7751 100644 --- a/R/pp_check.R +++ b/R/pp_check.R @@ -16,7 +16,7 @@ #' If \code{NULL} all draws are used. If not specified, #' the number of posterior draws is chosen automatically. #' Ignored if \code{draw_ids} is not \code{NULL}. -#' @param prefix The prefix of the \pkg{bayesplot} function to be applied. 
+#' @param prefix The prefix of the \pkg{bayesplot} function to be applied. #' Either `"ppc"` (posterior predictive check; the default) #' or `"ppd"` (posterior predictive distribution), the latter being the same #' as the former except that the observed data is not shown for `"ppd"`. @@ -53,7 +53,7 @@ #' #' ## get an overview of all valid types #' pp_check(fit, type = "xyz") -#' +#' #' ## get a plot without the observed data #' pp_check(fit, prefix = "ppd") #' } @@ -62,7 +62,7 @@ #' @export pp_check #' @export pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd"), - group = NULL, x = NULL, newdata = NULL, resp = NULL, + group = NULL, x = NULL, newdata = NULL, resp = NULL, draw_ids = NULL, nsamples = NULL, subset = NULL, ...) { dots <- list(...) if (missing(type)) { @@ -124,7 +124,7 @@ pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd "error_scatter_avg", "error_scatter_avg_vs_x", "intervals", "intervals_grouped", "loo_intervals", "loo_pit", "loo_pit_overlay", - "loo_pit_qq", "loo_ribbon", + "loo_pit_qq", "loo_ribbon", 'pit_ecdf', 'pit_ecdf_grouped', "ribbon", "ribbon_grouped", "rootogram", "scatter_avg", "scatter_avg_grouped", @@ -147,7 +147,7 @@ pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd y <- NULL if (prefix == "ppc") { # y is ignored in prefix 'ppd' plots - y <- get_y(object, resp = resp, newdata = newdata, ...) + y <- get_y(object, resp = resp, newdata = newdata, ...) } draw_ids <- validate_draw_ids(object, draw_ids, ndraws) pred_args <- list( @@ -167,7 +167,7 @@ pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd object, newdata = newdata, resp = resp, re_formula = NA, check_response = TRUE, ... 
) - + # prepare plotting arguments ppc_args <- list() if (prefix == "ppc") { @@ -185,17 +185,16 @@ pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd ppc_args$x <- as.numeric(ppc_args$x) } } - if ("psis_object" %in% setdiff(names(formals(ppc_fun)), names(ppc_args))) { - ppc_args$psis_object <- do_call( - compute_loo, c(pred_args, criterion = "psis") - ) - } if ("lw" %in% setdiff(names(formals(ppc_fun)), names(ppc_args))) { - ppc_args$lw <- weights( - do_call(compute_loo, c(pred_args, criterion = "psis")) - ) + # run loo instead of psis to allow for moment matching + loo_object <- do_call(loo, c(pred_args, save_psis = TRUE)) + ppc_args$lw <- weights(loo_object$psis_object, log = TRUE) + } else if ("psis_object" %in% setdiff(names(formals(ppc_fun)), names(ppc_args))) { + # some PPCs may only support 'psis_object' but not 'lw' for whatever reason + loo_object <- do_call(loo, c(pred_args, save_psis = TRUE)) + ppc_args$psis_object <- loo_object$psis_object } - + # censored responses are misleading when displayed in pp_check bterms <- brmsterms(object$formula) cens <- get_cens(bterms, data, resp = resp) @@ -213,20 +212,17 @@ pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd if (!is.null(ppc_args$x)) { ppc_args$x <- ppc_args$x[take] } - if (!is.null(ppc_args$psis_object)) { - # tidier to re-compute with subset - psis_args <- c(pred_args, criterion = "psis") - psis_args$newdata <- data[take, ] - ppc_args$psis_object <- do_call(compute_loo, psis_args) - } if (!is.null(ppc_args$lw)) { - ppc_args$lw <- ppc_args$lw[,take] + ppc_args$lw <- ppc_args$lw[, take] + } else if (!is.null(ppc_args$psis_object)) { + # we only need the log weights so the rest can remain unchanged + ppc_args$psis_object$log_weights <- ppc_args$psis_object$log_weights[, take] } } - + # most ... 
arguments are meant for the prediction function for_pred <- names(dots) %in% names(formals(prepare_predictions.brmsfit)) ppc_args <- c(ppc_args, dots[!for_pred]) - + do_call(ppc_fun, ppc_args) } diff --git a/man/pp_check.brmsfit.Rd b/man/pp_check.brmsfit.Rd index 737333f86..f81881fb8 100644 --- a/man/pp_check.brmsfit.Rd +++ b/man/pp_check.brmsfit.Rd @@ -35,7 +35,7 @@ If \code{NULL} all draws are used. If not specified, the number of posterior draws is chosen automatically. Ignored if \code{draw_ids} is not \code{NULL}.} -\item{prefix}{The prefix of the \pkg{bayesplot} function to be applied. +\item{prefix}{The prefix of the \pkg{bayesplot} function to be applied. Either `"ppc"` (posterior predictive check; the default) or `"ppd"` (posterior predictive distribution), the latter being the same as the former except that the observed data is not shown for `"ppd"`.} From b4410607713093f3f517a44f2fc59a65ed603c4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Mon, 16 Sep 2024 09:44:18 +0200 Subject: [PATCH 20/21] fix post-hoc basis computation --- R/standata.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/standata.R b/R/standata.R index 80487c79b..06f927415 100644 --- a/R/standata.R +++ b/R/standata.R @@ -198,7 +198,8 @@ standata.brmsfit <- function(object, newdata = NULL, re_formula = NULL, # the 'empty' feature. 
But computing it here will be fine # for almost all models, only causing potential problems for processing # of splines on new machines (#1465) - basis <- frame_basis(bterms, data = object$data) + bframe_old <- brmsframe(object$formula, data = object$data) + basis <- frame_basis(bframe_old, data = object$data) } bframe <- brmsframe(bterms, data = data, basis = basis) .standata( From 5bb6531e72fca564619f0b8729c6ec430d0e8f32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Mon, 16 Sep 2024 10:59:54 +0200 Subject: [PATCH 21/21] feature issue #1489 --- DESCRIPTION | 4 +- NAMESPACE | 1 + NEWS.md | 1 + R/brmsformula.R | 18 ++++++--- R/brmsframe.R | 4 +- R/data-response.R | 61 +++++++++++++++++++++++----- R/families.R | 45 +++++++++------------ R/family-lists.R | 2 +- R/formula-ad.R | 12 ++++++ R/prepare_predictions.R | 21 ++++++++-- R/priors.R | 4 ++ R/rename_pars.R | 20 +++++++++ R/stan-likelihood.R | 11 +++-- R/stan-response.R | 72 +++++++++++++++++++++++++-------- man/addition-terms.Rd | 15 ++++--- man/brmsfamily.Rd | 7 +--- tests/local/tests.models-5.R | 7 ++++ tests/testthat/tests.stancode.R | 8 +++- tests/testthat/tests.standata.R | 16 +++++++- 19 files changed, 242 insertions(+), 87 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 2819b8f87..dfa84308f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -2,8 +2,8 @@ Package: brms Encoding: UTF-8 Type: Package Title: Bayesian Regression Models using 'Stan' -Version: 2.21.8 -Date: 2024-09-12 +Version: 2.21.9 +Date: 2024-09-16 Authors@R: c(person("Paul-Christian", "Bürkner", email = "paul.buerkner@gmail.com", role = c("aut", "cre")), diff --git a/NAMESPACE b/NAMESPACE index 484e35628..f2539edac 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -570,6 +570,7 @@ export(read_csv_as_stanfit) export(recompile_model) export(reloo) export(rename_pars) +export(resp_bhaz) export(resp_cat) export(resp_cens) export(resp_dec) diff --git a/NEWS.md b/NEWS.md index 027c9b35c..6362aa10b 100644 --- a/NEWS.md 
+++ b/NEWS.md @@ -2,6 +2,7 @@ ### New Features +* Support stratified `cox` models via the new addition term `bhaz`. (#1489) * Support futures for parallelization in the `cmdstanr` backend. (#1684) * Add method `loo_epred` thanks to Aki Vehtari. (#1641) * Add priorsense support via `create_priorsense_data.brmsfit` thanks to Noa Kallioinen. (#1354) diff --git a/R/brmsformula.R b/R/brmsformula.R index 4a421346b..a627a7803 100644 --- a/R/brmsformula.R +++ b/R/brmsformula.R @@ -1296,12 +1296,6 @@ validate_formula.brmsformula <- function( out$family$thres <- extract_thres_names(out, data) out$family$cats <- extract_cat_names(out, data) } - if (is.mixfamily(out$family)) { - # every mixture family needs to know about response categories - for (i in seq_along(out$family$mix)) { - out$family$mix[[i]]$thres <- out$family$thres - } - } } conv_cats_dpars <- conv_cats_dpars(out$family) if (conv_cats_dpars && !is.null(data)) { @@ -1337,6 +1331,18 @@ validate_formula.brmsformula <- function( out$family$dpars <- union(dp_dpars, out$family$dpars) } } + if (is_cox(out$family) && !is.null(data)) { + # for easy access of baseline hazards + out$family$bhaz <- extract_bhaz(out, data) + } + if (is.mixfamily(out$family)) { + # every mixture family needs to know about additional response information + for (i in seq_along(out$family$mix)) { + for (term in c("cats", "thres", "bhaz")) { + out$family$mix[[i]][[term]] <- out$family[[term]] + } + } + } # incorporate deprecated arguments require_threshold <- is_ordinal(out$family) && is.null(out$family$threshold) diff --git a/R/brmsframe.R b/R/brmsframe.R index 0075cd764..6b23f4c11 100644 --- a/R/brmsframe.R +++ b/R/brmsframe.R @@ -418,8 +418,8 @@ frame_basis_bhaz <- function(x, data, ...) 
{ if (is_cox(x$family)) { # compute basis matrix of the baseline hazard for the Cox model y <- model.response(model.frame(x$respform, data, na.action = na.pass)) - out$basis_matrix <- bhaz_basis_matrix(y, args = x$family$bhaz) + args <- family_info(x, "bhaz")$args + out$basis_matrix <- bhaz_basis_matrix(y, args = args) } out } - diff --git a/R/data-response.R b/R/data-response.R index bd23be014..dfb90a4ac 100644 --- a/R/data-response.R +++ b/R/data-response.R @@ -469,14 +469,34 @@ data_bhaz <- function(bframe, data, data2, prior) { return(out) } y <- bframe$frame$resp$values - args <- bframe$family$bhaz + bhaz <- family_info(bframe, "bhaz") bs <- bframe$basis$bhaz$basis_matrix - out$Zbhaz <- bhaz_basis_matrix(y, args, basis = bs) - out$Zcbhaz <- bhaz_basis_matrix(y, args, integrate = TRUE, basis = bs) + out$Zbhaz <- bhaz_basis_matrix(y, bhaz$args, basis = bs) + out$Zcbhaz <- bhaz_basis_matrix(y, bhaz$args, integrate = TRUE, basis = bs) out$Kbhaz <- NCOL(out$Zbhaz) - sbhaz_prior <- subset2(prior, class = "sbhaz", resp = bframe$resp) - con_sbhaz <- eval_dirichlet(sbhaz_prior$prior, out$Kbhaz, data2) - out$con_sbhaz <- as.array(con_sbhaz) + groups <- bhaz$groups + if (!is.null(groups)) { + out$ngrbhaz <- length(groups) + gr <- get_ad_values(bframe, "bhaz", "gr", data) + gr <- factor(rename(gr), levels = groups) + out$Jgrbhaz <- match(gr, groups) + out$con_sbhaz <- matrix(nrow = out$ngrbhaz, ncol = out$Kbhaz) + sbhaz_prior <- subset2(prior, class = "sbhaz", resp = bframe$resp) + sbhaz_prior_global <- subset2(sbhaz_prior, group = "") + con_sbhaz_global <- eval_dirichlet(sbhaz_prior_global$prior, out$Kbhaz, data2) + for (k in seq_along(groups)) { + sbhaz_prior_group <- subset2(sbhaz_prior, group = groups[k]) + if (nzchar(sbhaz_prior_group$prior)) { + out$con_sbhaz[k, ] <- eval_dirichlet(sbhaz_prior_group$prior, out$Kbhaz, data2) + } else { + out$con_sbhaz[k, ] <- con_sbhaz_global + } + } + } else { + sbhaz_prior <- subset2(prior, class = "sbhaz", resp = bframe$resp) + 
con_sbhaz <- eval_dirichlet(sbhaz_prior$prior, out$Kbhaz, data2) + out$con_sbhaz <- as.array(con_sbhaz) + } out } @@ -502,9 +522,6 @@ bhaz_basis_matrix <- function(y, args = list(), integrate = FALSE, } stopifnot(is.list(args)) args$x <- y - if (!is.null(args$intercept)) { - args$intercept <- as_one_logical(args$intercept) - } if (is.null(args$Boundary.knots)) { # avoid 'knots' outside 'Boundary.knots' error (#1143) # we also need a smaller lower boundary knot to avoid lp = -Inf @@ -524,6 +541,29 @@ bhaz_basis_matrix <- function(y, args = list(), integrate = FALSE, out } +# extract baseline hazard information from data for storage in the model family +# @return a named list with elements: +# args: arguments that can be passed to bhaz_basis_matrix +# groups: optional names of the groups for which to stratify +extract_bhaz <- function(x, data) { + stopifnot(is.brmsformula(x) || is.brmsterms(x), is_cox(x)) + if (is.null(x$adforms)) { + x$adforms <- terms_ad(x$formula, x$family) + } + out <- list() + if (is.null(x$adforms$bhaz)) { + # bhaz is an optional addition term so defaults need to be listed here too + out$args <- list(df = 5, intercept = TRUE) + } else { + out$args <- eval_rhs(x$adforms$bhaz)$flags + gr <- get_ad_values(x, "bhaz", "gr", data) + if (!is.null(gr)) { + out$groups <- rename(levels(factor(gr))) + } + } + out +} + # extract names of response categories # @param x a brmsterms object or one that can be coerced to it # @param data user specified data @@ -550,7 +590,6 @@ extract_cat_names <- function(x, data) { # @return a data.frame with columns 'thres' and 'group' extract_thres_names <- function(x, data) { stopifnot(is.brmsformula(x) || is.brmsterms(x), has_thres(x)) - if (is.null(x$adforms)) { x$adforms <- terms_ad(x$formula, x$family) } @@ -609,7 +648,7 @@ extract_thres_names <- function(x, data) { data.frame(thres, group, stringsAsFactors = FALSE) } -# extract threshold names from the response values +# extract number of thresholds from the response 
values # @param formula with the response on the LHS # @param data a data.frame from which to extract responses # @param extra_cat is the first category an extra (hurdle) category? diff --git a/R/families.R b/R/families.R index fa35b8593..0839aa7cd 100644 --- a/R/families.R +++ b/R/families.R @@ -56,7 +56,6 @@ #' category is used as the reference. If \code{NA}, all categories will be #' predicted, which requires strong priors or carefully specified predictor #' terms in order to lead to an identified model. -#' @param bhaz Currently for experimental purposes only. #' #' @details #' Below, we list common use cases for the different families. @@ -199,7 +198,7 @@ brmsfamily <- function(family, link = NULL, link_sigma = "log", link_alpha = "identity", link_quantile = "logit", threshold = "flexible", - refcat = NULL, bhaz = NULL) { + refcat = NULL) { slink <- substitute(link) .brmsfamily( family, link = link, slink = slink, @@ -212,8 +211,7 @@ brmsfamily <- function(family, link = NULL, link_sigma = "log", link_ndt = link_ndt, link_bias = link_bias, link_alpha = link_alpha, link_xi = link_xi, link_quantile = link_quantile, - threshold = threshold, refcat = refcat, - bhaz = bhaz + threshold = threshold, refcat = refcat ) } @@ -227,7 +225,7 @@ brmsfamily <- function(family, link = NULL, link_sigma = "log", # @return an object of 'brmsfamily' which inherits from 'family' .brmsfamily <- function(family, link = NULL, slink = link, threshold = "flexible", - refcat = NULL, bhaz = NULL, ...) { + refcat = NULL, ...) { family <- tolower(as_one_character(family)) aux_links <- list(...) 
pattern <- c("^normal$", "^zi_", "^hu_") @@ -300,23 +298,6 @@ brmsfamily <- function(family, link = NULL, link_sigma = "log", out$refcat <- as_one_character(refcat, allow_na = allow_na_ref) } } - if (is_cox(out$family)) { - if (!is.null(bhaz)) { - if (!is.list(bhaz)) { - stop2("'bhaz' should be a list.") - } - out$bhaz <- bhaz - } else { - out$bhaz <- list() - } - # set default arguments - if (is.null(out$bhaz$df)) { - out$bhaz$df <- 5L - } - if (is.null(out$bhaz$intercept)) { - out$bhaz$intercept <- TRUE - } - } out } @@ -475,8 +456,8 @@ combine_family_info <- function(x, y, ...) { clb <- !any(ulapply(x[, 1], isFALSE)) cub <- !any(ulapply(x[, 2], isFALSE)) x <- c(clb, cub) - } else if (y == "thres") { - # thresholds are the same across mixture components + } else if (y %in% c("thres", "bhaz")) { + # same across mixture components x <- x[[1]] } x @@ -687,9 +668,9 @@ zero_inflated_asym_laplace <- function(link = "identity", link_sigma = "log", #' @rdname brmsfamily #' @export -cox <- function(link = "log", bhaz = NULL) { +cox <- function(link = "log") { slink <- substitute(link) - .brmsfamily("cox", link = link, bhaz = bhaz) + .brmsfamily("cox", link = link) } #' @rdname brmsfamily @@ -1750,6 +1731,18 @@ has_thres_groups <- function(family) { any(nzchar(groups)) } +# get group names of baseline hazard groups +get_bhaz_groups <- function(family) { + bhaz <- family_info(family, "bhaz") + unique(bhaz$groups) +} + +# has the model group specific baseline hazards? 
+has_bhaz_groups <- function(family) { + groups <- get_bhaz_groups(family) + any(nzchar(groups)) +} + has_ndt <- function(family) { "ndt" %in% dpar_class(family_info(family, "dpars")) } diff --git a/R/family-lists.R b/R/family-lists.R index 10c3df4d8..e4d53ad9c 100644 --- a/R/family-lists.R +++ b/R/family-lists.R @@ -387,7 +387,7 @@ links = c("log", "identity", "softplus", "squareplus"), dpars = c("mu"), type = "real", ybounds = c(0, Inf), closed = c(TRUE, NA), - ad = c("weights", "subset", "cens", "trunc", "index"), + ad = c("weights", "subset", "cens", "trunc", "index", "bhaz"), include = "fun_cox.stan", specials = c("cox", "sbi_log", "sbi_log_cdf"), normalized = "" diff --git a/R/formula-ad.R b/R/formula-ad.R index 7f8b2dc2c..c02c84b34 100644 --- a/R/formula-ad.R +++ b/R/formula-ad.R @@ -43,6 +43,7 @@ #' @param denom A vector of positive numeric values specifying #' the denominator values from which the response rates are computed. #' @param gr A vector of grouping indicators. +#' @param df Degrees of freedom of baseline hazard splines for Cox models. #' @param ... For \code{resp_vreal}, vectors of real values. #' For \code{resp_vint}, vectors of integer values. In Stan, #' these variables will be named \code{vreal1}, \code{vreal2}, ..., @@ -160,6 +161,17 @@ resp_dec <- function(x) { class_resp_special("dec", call = match.call(), vars = nlist(dec)) } +#' @rdname addition-terms +#' @export +resp_bhaz <- function(gr = NA, df = 5, ...) { + gr <- deparse0(substitute(gr)) + df <- as_one_integer(df) + args <- nlist(df, ...) 
+ # non-power users shouldn't know they can change 'intercept' + args$intercept <- args$intercept %||% TRUE + class_resp_special("bhaz", call = match.call(), vars = nlist(gr), flags = args) +} + #' @rdname addition-terms #' @export resp_cens <- function(x, y2 = NA) { diff --git a/R/prepare_predictions.R b/R/prepare_predictions.R index 0d59eb993..640e06b90 100644 --- a/R/prepare_predictions.R +++ b/R/prepare_predictions.R @@ -846,12 +846,25 @@ prepare_predictions_bhaz <- function(bframe, draws, sdata, ...) { } out <- list() p <- usc(combine_prefix(bframe)) - sbhaz_regex <- paste0("^sbhaz", p) - sbhaz <- prepare_draws(draws, sbhaz_regex, regex = TRUE) Zbhaz <- sdata[[paste0("Zbhaz", p)]] - out$bhaz <- tcrossprod(sbhaz, Zbhaz) Zcbhaz <- sdata[[paste0("Zcbhaz", p)]] - out$cbhaz <- tcrossprod(sbhaz, Zcbhaz) + if (has_bhaz_groups(bframe)) { + groups <- get_bhaz_groups(bframe) + Jgrbhaz <- sdata[[paste0("Jgrbhaz", p)]] + out$bhaz <- out$cbhaz <- matrix(nrow = nrow(draws), ncol = nrow(Zbhaz)) + for (k in seq_along(groups)) { + sbhaz_regex <- paste0("^sbhaz", p, "\\[", groups[k], ",") + sbhaz <- prepare_draws(draws, sbhaz_regex, regex = TRUE) + take <- Jgrbhaz == k + out$bhaz[, take] <- tcrossprod(sbhaz, Zbhaz[take, ]) + out$cbhaz[, take] <- tcrossprod(sbhaz, Zcbhaz[take, ]) + } + } else { + sbhaz_regex <- paste0("^sbhaz", p) + sbhaz <- prepare_draws(draws, sbhaz_regex, regex = TRUE) + out$bhaz <- tcrossprod(sbhaz, Zbhaz) + out$cbhaz <- tcrossprod(sbhaz, Zcbhaz) + } out } diff --git a/R/priors.R b/R/priors.R index 83d7fad62..3edac0210 100644 --- a/R/priors.R +++ b/R/priors.R @@ -760,6 +760,10 @@ prior_bhaz <- function(bframe, ...) 
{ # the scale of sbhaz is not identified when an intercept is part of mu # thus a sum-to-one constraint ensures identification prior <- prior + brmsprior("dirichlet(1)", class = "sbhaz", ls = px) + if (has_bhaz_groups(bframe)) { + groups <- get_bhaz_groups(bframe) + prior <- prior + brmsprior("", class = "sbhaz", ls = px, group = groups) + } prior } diff --git a/R/rename_pars.R b/R/rename_pars.R index 22731e475..98c7dfffe 100644 --- a/R/rename_pars.R +++ b/R/rename_pars.R @@ -86,6 +86,7 @@ rename_predictor.brmsterms <- function(x, ...) { c(out) <- rename_Ymi(x, ...) } c(out) <- rename_thres(x, ...) + c(out) <- rename_bhaz(x, ...) c(out) <- rename_family_cor_pars(x, ...) out } @@ -201,6 +202,25 @@ rename_thres <- function(bframe, pars, ...) { out } +# rename baseline hazard parameters in cox models +rename_bhaz <- function(bframe, pars, ...) { + out <- list() + # renaming is only required if multiple threshold were estimated + if (!has_bhaz_groups(bframe)) { + return(out) + } + px <- check_prefix(bframe) + p <- usc(combine_prefix(px)) + groups <- get_bhaz_groups(bframe) + for (k in seq_along(groups)) { + pos <- grepl(glue("^sbhaz{p}\\[{k},"), pars) + funs <- seq_len(sum(pos)) + bhaz_names <- glue("sbhaz{p}[{groups[k]},{funs}]") + lc(out) <- rlist(pos, bhaz_names) + } + out +} + # helps in renaming global noise free variables # @param meframe data.frame returned by 'frame_me' rename_Xme <- function(bframe, pars, ...) { diff --git a/R/stan-likelihood.R b/R/stan-likelihood.R index adeeb2da0..19a98d8c3 100644 --- a/R/stan-likelihood.R +++ b/R/stan-likelihood.R @@ -116,7 +116,7 @@ stan_log_lik_cens <- function(ll, bterms, threads, normalize, resp = "", ...) 
{ tr <- stan_log_lik_trunc(ll, bterms, resp = resp, threads = threads) tp <- tp() out <- glue( - "// special treatment of censored data\n", + " // special treatment of censored data\n", s, "if (cens{resp}{n} == 0) {{\n", s, "{tp}{w}{ll$dist}_{lpdf}({Y}{resp}{n}{ll$shift} | {ll$args}){tr};\n", s, "}} else if (cens{resp}{n} == 1) {{\n", @@ -168,7 +168,7 @@ stan_log_lik_mix <- function(ll, bterms, mix, ptheta, threads, cens <- eval_rhs(bterms$adforms$cens) s <- wsp(nsp = 4) out <- glue( - "// special treatment of censored data\n", + " // special treatment of censored data\n", s, "if (cens{resp}{n} == 0) {{\n", s, " ps[{mix}] = {theta} + ", "{ll$dist}_{lpdf}({Y}{resp}{n}{ll$shift} | {ll$args}){tr};\n", @@ -714,7 +714,7 @@ stan_log_lik_wiener <- function(bterms, resp = "", mix = "", threads = NULL, } stan_log_lik_beta <- function(bterms, resp = "", mix = "", ...) { - # TODO: check if we still require n when phi is predicted + # TODO: check if we still require n when phi is predicted # and check the same for other families too reqn <- stan_log_lik_adj(bterms) || nzchar(mix) || paste0("phi", mix) %in% names(bterms$dpars) @@ -726,13 +726,12 @@ stan_log_lik_beta <- function(bterms, resp = "", mix = "", ...) { } stan_log_lik_von_mises <- function(bterms, resp = "", mix = "", ...) { - reqn <- stan_log_lik_adj(bterms) || nzchar(mix) + reqn <- stan_log_lik_adj(bterms) || nzchar(mix) p <- stan_log_lik_dpars(bterms, reqn, resp, mix) sdist("von_mises", p$mu, p$kappa) } -stan_log_lik_cox <- function(bterms, resp = "", mix = "", threads = NULL, - ...) { +stan_log_lik_cox <- function(bterms, resp = "", mix = "", threads = NULL, ...) { p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) p$bhaz <- paste0("bhaz", resp, "[n]") p$cbhaz <- paste0("cbhaz", resp, "[n]") diff --git a/R/stan-response.R b/R/stan-response.R index 70e3e3991..0b23b21c5 100644 --- a/R/stan-response.R +++ b/R/stan-response.R @@ -372,6 +372,7 @@ stan_bhaz <- function(bterms, prior, threads, normalize, ...) 
{ px <- check_prefix(bterms) p <- usc(combine_prefix(px)) resp <- usc(px$resp) + n <- stan_nn(threads) slice <- stan_slice(threads) str_add(out$data) <- glue( " // data for flexible baseline functions\n", @@ -379,25 +380,64 @@ stan_bhaz <- function(bterms, prior, threads, normalize, ...) { " // design matrix of the baseline function\n", " matrix[N{resp}, Kbhaz{resp}] Zbhaz{resp};\n", " // design matrix of the cumulative baseline function\n", - " matrix[N{resp}, Kbhaz{resp}] Zcbhaz{resp};\n", - " // a-priori concentration vector of baseline coefficients\n", - " vector[Kbhaz{resp}] con_sbhaz{resp};\n" - ) - str_add(out$par) <- glue( - " simplex[Kbhaz{resp}] sbhaz{resp}; // baseline coefficients\n" - ) - str_add(out$tpar_prior) <- glue( - " lprior += dirichlet_{lpdf}(sbhaz{resp} | con_sbhaz{resp});\n" - ) - str_add(out$model_def) <- glue( - " // compute values of baseline function\n", - " vector[N{resp}] bhaz{resp} = Zbhaz{resp}{slice} * sbhaz{resp};\n", - " // compute values of cumulative baseline function\n", - " vector[N{resp}] cbhaz{resp} = Zcbhaz{resp}{slice} * sbhaz{resp};\n" + " matrix[N{resp}, Kbhaz{resp}] Zcbhaz{resp};\n" ) str_add(out$pll_args) <- glue( - ", data matrix Zbhaz{resp}, data matrix Zcbhaz{resp}, vector sbhaz{resp}" + ", data matrix Zbhaz{resp}, data matrix Zcbhaz{resp}" ) + if (has_bhaz_groups(bterms)) { + # stratified baseline hazards with separate functions per group + str_add(out$data) <- glue( + " // data for stratification of baseline hazards\n", + " int ngrbhaz{resp}; // number of groups\n", + " array[N{resp}] int Jgrbhaz{resp}; // group indices per observation\n", + " // a-priori concentration vector of baseline coefficients\n", + " array[ngrbhaz{resp}] vector[Kbhaz{resp}] con_sbhaz{resp};\n" + ) + str_add(out$par) <- glue( + " // stratified baseline hazard coefficients\n", + " array[ngrbhaz{resp}] simplex[Kbhaz{resp}] sbhaz{resp};\n" + ) + str_add(out$tpar_prior) <- glue( + " for (k in 1:ngrbhaz{resp}) {{\n", + " lprior += 
dirichlet_{lpdf}(sbhaz{resp}[k] | con_sbhaz{resp}[k]);\n", + " }}\n" + ) + str_add(out$model_def) <- glue( + " // stratified baseline hazard functions\n", + " vector[N{resp}] bhaz{resp};\n", + " vector[N{resp}] cbhaz{resp};\n" + ) + str_add(out$model_comp_basic) <- glue( + " // compute values of stratified baseline hazard functions\n", + " for (n in 1:N{resp}) {{\n", + stan_nn_def(threads), + " bhaz{resp}{n} = Zbhaz{resp}{n} * sbhaz{resp}[Jgrbhaz{resp}{n}];\n", + " cbhaz{resp}{n} = Zcbhaz{resp}{n} * sbhaz{resp}[Jgrbhaz{resp}{n}];\n", + " }}\n" + ) + str_add(out$pll_args) <- glue(", array[] sbhaz{resp}") + } else { + # a single baseline hazard function + str_add(out$data) <- glue( + " // a-priori concentration vector of baseline coefficients\n", + " vector[Kbhaz{resp}] con_sbhaz{resp};\n" + ) + str_add(out$par) <- glue( + " // baseline hazard coefficients\n", + " simplex[Kbhaz{resp}] sbhaz{resp};\n" + ) + str_add(out$tpar_prior) <- glue( + " lprior += dirichlet_{lpdf}(sbhaz{resp} | con_sbhaz{resp});\n" + ) + str_add(out$model_def) <- glue( + " // compute values of baseline function\n", + " vector[N{resp}] bhaz{resp} = Zbhaz{resp}{slice} * sbhaz{resp};\n", + " // compute values of cumulative baseline function\n", + " vector[N{resp}] cbhaz{resp} = Zcbhaz{resp}{slice} * sbhaz{resp};\n" + ) + str_add(out$pll_args) <- glue(", vector sbhaz{resp}") + } out } diff --git a/man/addition-terms.Rd b/man/addition-terms.Rd index fcf59eff2..5fd3dea10 100644 --- a/man/addition-terms.Rd +++ b/man/addition-terms.Rd @@ -21,6 +21,7 @@ \alias{resp_thres} \alias{resp_cat} \alias{resp_dec} +\alias{resp_bhaz} \alias{resp_cens} \alias{resp_trunc} \alias{resp_mi} @@ -43,6 +44,8 @@ resp_cat(x) resp_dec(x) +resp_bhaz(gr = NA, df = 5, ...) + resp_cens(x, y2 = NA) resp_trunc(lb = -Inf, ub = Inf) @@ -83,6 +86,13 @@ so that the average weight equals one. 
Defaults to \code{FALSE}.} \item{gr}{A vector of grouping indicators.} +\item{df}{Degrees of freedom of baseline hazard splines for Cox models.} + +\item{...}{For \code{resp_vreal}, vectors of real values. +For \code{resp_vint}, vectors of integer values. In Stan, +these variables will be named \code{vreal1}, \code{vreal2}, ..., +and \code{vint1}, \code{vint2}, ..., respectively.} + \item{y2}{A vector specifying the upper bounds in interval censoring. Will be ignored for non-interval censored observations. However, it should NOT be \code{NA} even for non-interval censored observations to @@ -101,11 +111,6 @@ at the same time using the plausible-values-technique.} \item{denom}{A vector of positive numeric values specifying the denominator values from which the response rates are computed.} - -\item{...}{For \code{resp_vreal}, vectors of real values. -For \code{resp_vint}, vectors of integer values. In Stan, -these variables will be named \code{vreal1}, \code{vreal2}, ..., -and \code{vint1}, \code{vint2}, ..., respectively.} } \value{ A list of additional response information to be processed further diff --git a/man/brmsfamily.Rd b/man/brmsfamily.Rd index be01ae0cb..533294306 100644 --- a/man/brmsfamily.Rd +++ b/man/brmsfamily.Rd @@ -62,8 +62,7 @@ brmsfamily( link_alpha = "identity", link_quantile = "logit", threshold = "flexible", - refcat = NULL, - bhaz = NULL + refcat = NULL ) student(link = "identity", link_sigma = "log", link_nu = "logm1") @@ -109,7 +108,7 @@ von_mises(link = "tan_half", link_kappa = "log") asym_laplace(link = "identity", link_sigma = "log", link_quantile = "logit") -cox(link = "log", bhaz = NULL) +cox(link = "log") hurdle_poisson(link = "log", link_hu = "logit") @@ -227,8 +226,6 @@ consecutive thresholds to the same value, and category is used as the reference. 
If \code{NA}, all categories will be predicted, which requires strong priors or carefully specified predictor terms in order to lead to an identified model.} - -\item{bhaz}{Currently for experimental purposes only.} } \description{ Family objects provide a convenient way to specify the details of the models diff --git a/tests/local/tests.models-5.R b/tests/local/tests.models-5.R index 28ba6bc51..186fb9d37 100644 --- a/tests/local/tests.models-5.R +++ b/tests/local/tests.models-5.R @@ -87,12 +87,19 @@ test_that("Cox models work correctly", { d1 <- simsurv::simsurv(lambdas = 0.1, gammas = 1.5, betas = c(trt = -0.5), x = covs, maxt = 5) d1 <- merge(d1, covs) + d1$g <- sample(c("a", "b"), nrow(d1), TRUE) fit1 <- brm(eventtime | cens(1 - status) ~ 1 + trt, data = d1, family = brmsfamily("cox"), refresh = 0) print(summary(fit1)) expect_range(posterior_summary(fit1)["b_trt", "Estimate"], -0.70, -0.30) expect_range(waic(fit1)$estimates[3, 1], 620, 670) + + fit2 <- brm(eventtime | cens(1 - status) + bhaz(gr = g) ~ 1 + trt, + data = d1, family = brmsfamily("cox"), refresh = 0) + print(summary(fit2)) + expect_true("sbhaz[a,2]" %in% variables(fit2)) + expect_range(waic(fit2)$estimates[3, 1], 620, 670) }) test_that("ordinal model with grouped thresholds works correctly", { diff --git a/tests/testthat/tests.stancode.R b/tests/testthat/tests.stancode.R index 274c81238..441daf028 100644 --- a/tests/testthat/tests.stancode.R +++ b/tests/testthat/tests.stancode.R @@ -1707,7 +1707,8 @@ test_that("Stan code of GEV models is correct", { }) test_that("Stan code of Cox models is correct", { - data <- data.frame(y = rexp(100), ce = sample(0:1, 100, TRUE), x = rnorm(100)) + data <- data.frame(y = rexp(100), ce = sample(0:1, 100, TRUE), + x = rnorm(100), g = sample(1:3, 100, TRUE)) bform <- bf(y | cens(ce) ~ x) scode <- stancode(bform, data, brmsfamily("cox")) expect_match2(scode, "target += cox_log_lpdf(Y[n] | mu[n], bhaz[n], cbhaz[n]);") @@ -1717,6 +1718,11 @@ test_that("Stan code of Cox 
models is correct", { scode <- stancode(bform, data, brmsfamily("cox", "identity")) expect_match2(scode, "target += cox_lccdf(Y[n] | mu[n], bhaz[n], cbhaz[n]);") + + bform <- bf(y | bhaz(gr = g) ~ x) + scode <- stancode(bform, data, brmsfamily("cox")) + expect_match2(scode, "lprior += dirichlet_lpdf(sbhaz[k] | con_sbhaz[k]);") + expect_match2(scode, "bhaz[n] = Zbhaz[n] * sbhaz[Jgrbhaz[n]];") }) test_that("offsets appear in the Stan code", { diff --git a/tests/testthat/tests.standata.R b/tests/testthat/tests.standata.R index 7ee8be056..c16e3d11f 100644 --- a/tests/testthat/tests.standata.R +++ b/tests/testthat/tests.standata.R @@ -1006,7 +1006,9 @@ test_that("data for multinomial and dirichlet models is correct", { }) test_that("standata handles cox models correctly", { - data <- data.frame(y = rexp(100), x = rnorm(100)) + data <- data.frame(y = rexp(100), x = rnorm(100), + g = sample(1:3, 100, TRUE)) + bform <- bf(y ~ x) bprior <- prior(dirichlet(3), sbhaz) sdata <- standata(bform, data, brmsfamily("cox"), prior = bprior) @@ -1014,9 +1016,19 @@ test_that("standata handles cox models correctly", { expect_equal(dim(sdata$Zcbhaz), c(100, 5)) expect_equal(sdata$con_sbhaz, as.array(rep(3, 5))) - sdata <- standata(bform, data, brmsfamily("cox", bhaz = list(df = 6))) + bform <- bf(y | bhaz(df = 6) ~ x) + sdata <- standata(bform, data, brmsfamily("cox")) expect_equal(dim(sdata$Zbhaz), c(100, 6)) expect_equal(dim(sdata$Zcbhaz), c(100, 6)) + + bform <- bf(y | bhaz(gr = g) ~ x) + bprior <- prior(dirichlet(3), "sbhaz", group = 2) + sdata <- standata(bform, data, family = brmsfamily("cox"), + prior = bprior) + expect_equal(sdata$ngrbhaz, 3) + expect_equivalent(sdata$Jgrbhaz, data$g) + con_mat <- rbind(rep(1, 5), rep(3, 5), rep(1, 5)) + expect_equivalent(sdata$con_sbhaz, con_mat) }) test_that("standata handles addition term 'rate' is correctly", {