Skip to content

Commit

Permalink
Merge branch 'master' into re-predictors
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Sep 19, 2024
2 parents cda53a5 + c0eb374 commit d728ae3
Show file tree
Hide file tree
Showing 29 changed files with 956 additions and 586 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ tests/local/models_0.10.0.Rda
tests/local/models_1.2.0.Rda
tests/local/Rplots.pdf
/Meta/
.DS_Store
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ Package: brms
Encoding: UTF-8
Type: Package
Title: Bayesian Regression Models using 'Stan'
Version: 2.21.9
Date: 2024-09-16
Version: 2.21.11
Date: 2024-09-19
Authors@R:
c(person("Paul-Christian", "Bürkner", email = "[email protected]",
role = c("aut", "cre")),
Expand Down
2 changes: 0 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,6 @@ S3method(restructure,brmsfit)
S3method(rhat,brmsfit)
S3method(shinystan::launch_shinystan,brmsfit)
S3method(stan_log_lik,brmsterms)
S3method(stan_log_lik,family)
S3method(stan_log_lik,mixfamily)
S3method(stan_log_lik,mvbrmsterms)
S3method(stan_predictor,bframel)
S3method(stan_predictor,bframenl)
Expand Down
6 changes: 4 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

### New Features

* Use group-level coefficients as predictors in other formulas via `re` terms.
* Support different Gaussian process kernels in `gp` terms. (#234)
* Support stratified `cox` models via the new addition term `bhaz`. (#1489)
* Support futures for parallelization in the `cmdstanr` backend. (#1684)
* Add method `loo_epred` thanks to Aki Vehtari. (#1641)
* Add priorsense support via `create_priorsense_data.brmsfit` thanks to Noa Kallioinen. (#1354)
* Add priorsense support via `create_priorsense_data.brmsfit`
thanks to Noa Kallioinen. (#1354)
* Vectorize censored log likelihoods in the Stan code when possible. (#1657)

### Bug Fixes

Expand Down
32 changes: 23 additions & 9 deletions R/brmsterms.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,12 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE,
y$cov_ranef <- x$cov_ranef
class(y) <- "brmsterms"

y$resp <- ""
if (check_response) {
# extract response variables
y$respform <- validate_resp_formula(formula, empty_ok = FALSE)
if (mv) {
y$resp <- terms_resp(y$respform)
} else {
y$resp <- ""
}
}

Expand All @@ -97,6 +96,11 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE,
x$pforms[[dp]] <- combine_formulas(formula, x$pforms[[dp]], dp)
}
x$pforms <- move2start(x$pforms, mu_dpars)
for (i in seq_along(family$mix)) {
# store the respective mixture index in each mixture component
# this enables them to be easily passed along, e.g. in stan_log_lik
y$family$mix[[i]]$mix <- i
}
} else if (conv_cats_dpars(x$family)) {
mu_dpars <- str_subset(x$family$dpars, "^mu")
for (dp in mu_dpars) {
Expand All @@ -109,21 +113,20 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE,
}

# predicted distributional parameters
resp <- ifelse(mv && !is.null(y$resp), y$resp, "")
dpars <- intersect(names(x$pforms), valid_dpars(family))
dpar_forms <- x$pforms[dpars]
nlpars <- setdiff(names(x$pforms), dpars)

y$dpars <- named_list(dpars)
for (dp in dpars) {
if (get_nl(dpar_forms[[dp]])) {
y$dpars[[dp]] <- terms_nlf(dpar_forms[[dp]], nlpars, resp)
y$dpars[[dp]] <- terms_nlf(dpar_forms[[dp]], nlpars, y$resp)
} else {
y$dpars[[dp]] <- terms_lf(dpar_forms[[dp]])
}
y$dpars[[dp]]$family <- dpar_family(family, dp)
y$dpars[[dp]]$dpar <- dp
y$dpars[[dp]]$resp <- resp
y$dpars[[dp]]$resp <- y$resp
if (dpar_class(dp) == "mu") {
y$dpars[[dp]]$respform <- y$respform
y$dpars[[dp]]$adforms <- y$adforms
Expand All @@ -142,12 +145,12 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE,
attr(nlpar_forms[[nlp]], "center") <- FALSE
}
if (get_nl(nlpar_forms[[nlp]])) {
y$nlpars[[nlp]] <- terms_nlf(nlpar_forms[[nlp]], nlpars, resp)
y$nlpars[[nlp]] <- terms_nlf(nlpar_forms[[nlp]], nlpars, y$resp)
} else {
y$nlpars[[nlp]] <- terms_lf(nlpar_forms[[nlp]])
}
y$nlpars[[nlp]]$nlpar <- nlp
y$nlpars[[nlp]]$resp <- resp
y$nlpars[[nlp]]$resp <- y$resp
check_cs(y$nlpars[[nlp]])
}
used_nlpars <- ufrom_list(c(y$dpars, y$nlpars), "used_nlpars")
Expand Down Expand Up @@ -591,20 +594,31 @@ is.btnl <- function(x) {
inherits(x, "btnl")
}

# figure out if a certain distributional parameter is predicted
is_pred_dpar <- function(bterms, dpar) {
stopifnot(is.brmsterms(bterms))
if (!length(dpar)) {
return(FALSE)
}
mix <- get_mix_id(bterms)
any(paste0(dpar, mix) %in% names(bterms$dpars))
}

# transform mvbrmsterms objects for use in stan_llh.brmsterms
as.brmsterms <- function(x) {
stopifnot(is.mvbrmsterms(x), x$rescor)
families <- ulapply(x$terms, function(y) y$family$family)
stopifnot(all(families == families[1]))
out <- structure(list(), class = "brmsterms")
out$family <- structure(
list(family = paste0(families[1], "_mv"), link = "identity"),
list(family = families[1], link = "identity"),
class = c("brmsfamily", "family")
)
out$family$fun <- paste0(out$family$family, "_mv")
info <- get(paste0(".family_", families[1]))()
out$family[names(info)] <- info
out$sigma_pred <- any(ulapply(x$terms,
function(x) "sigma" %in% names(x$dpar) || is.formula(x$adforms$se)
function(x) is_pred_dpar(x, "sigma") || has_ad_terms(x, "se")
))
weight_forms <- rmNULL(lapply(x$terms, function(x) x$adforms$weights))
if (length(weight_forms)) {
Expand Down
3 changes: 2 additions & 1 deletion R/conditional_effects.R
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,8 @@ get_int_vars.mvbrmsterms <- function(x, ...) {

#' @export
get_int_vars.brmsterms <- function(x, ...) {
advars <- ulapply(rmNULL(x$adforms[c("trials", "thres", "vint")]), all_vars)
adterms <- c("trials", "thres", "vint")
advars <- ulapply(rmNULL(x$adforms[adterms]), all_vars)
unique(c(advars, get_sp_vars(x, "mo")))
}

Expand Down
4 changes: 2 additions & 2 deletions R/data-predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -652,8 +652,8 @@ data_gp <- function(bframe, data, internal = FALSE, ...) {
XgpL <- matrix(nrow = NROW(Xgp), ncol = NROW(Ks))
slambda <- matrix(nrow = NROW(Ks), ncol = D)
for (m in seq_rows(Ks)) {
XgpL[, m] <- eigen_fun_cov_exp_quad(Xgp, m = Ks[m, ], L = L)
slambda[m, ] <- sqrt(eigen_val_cov_exp_quad(m = Ks[m, ], L = L))
XgpL[, m] <- eigen_fun_laplacian(Xgp, m = Ks[m, ], L = L)
slambda[m, ] <- sqrt(eigen_val_laplacian(m = Ks[m, ], L = L))
}
out[[paste0("Xgp", sfx)]] <- XgpL
out[[paste0("slambda", sfx)]] <- slambda
Expand Down
1 change: 1 addition & 0 deletions R/data-response.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ data_response.brmsframe <- function(x, data, check_response = TRUE,
}

# data for addition arguments of the response
# TODO: replace is.formula(x$adforms$term) pattern with has_ad_terms()
if (has_trials(x$family) || is.formula(x$adforms$trials)) {
if (!length(x$adforms$trials)) {
stop2("Specifying 'trials' is required for this model.")
Expand Down
18 changes: 11 additions & 7 deletions R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ combine_family_info <- function(x, y, ...) {
y <- as_one_character(y)
unite <- c(
"dpars", "type", "specials", "include",
"const", "cats", "ad", "normalized"
"const", "cats", "ad", "normalized", "mix"
)
if (y %in% c("family", "link")) {
x <- unlist(x)
Expand Down Expand Up @@ -1785,6 +1785,11 @@ no_nu <- function(bterms) {
isTRUE(bterms$rescor) && "student" %in% family_names(bterms)
}

# get mixture index if specified
get_mix_id <- function(family) {
family_info(family, "mix") %||% ""
}

# does the family-link combination have a built-in Stan function?
has_built_in_fun <- function(family, link = NULL, dpar = NULL, cdf = FALSE) {
link <- link %||% family$link
Expand All @@ -1802,19 +1807,18 @@ prepare_family <- function(x) {
stopifnot(is.brmsformula(x) || is.brmsterms(x))
family <- x$family
acframe <- frame_ac(x)
family$fun <- family[["fun"]] %||% family$family
if (use_ac_cov_time(acframe) && has_natural_residuals(x)) {
family$fun <- paste0(family$family, "_time")
family$fun <- paste0(family$fun, "_time")
} else if (has_ac_class(acframe, "sar")) {
acframe_sar <- subset2(acframe, class = "sar")
if (has_ac_subset(acframe_sar, type = "lag")) {
family$fun <- paste0(family$family, "_lagsar")
family$fun <- paste0(family$fun, "_lagsar")
} else if (has_ac_subset(acframe_sar, type = "error")) {
family$fun <- paste0(family$family, "_errorsar")
family$fun <- paste0(family$fun, "_errorsar")
}
} else if (has_ac_class(acframe, "fcor")) {
family$fun <- paste0(family$family, "_fcor")
} else {
family$fun <- family$family
family$fun <- paste0(family$fun, "_fcor")
}
family
}
Expand Down
16 changes: 10 additions & 6 deletions R/formula-ad.R
Original file line number Diff line number Diff line change
Expand Up @@ -376,21 +376,25 @@ trunc_bounds <- function(bterms, data = NULL, incl_family = FALSE,
out
}

# check if addition argument 'subset' ist used in the model
# check if addition argument 'subset' is used in the model
# works for both univariate and multivariate models
has_subset <- function(bterms) {
.has_subset <- function(x) {
is.formula(x$adforms$subset)
}
if (is.brmsterms(bterms)) {
out <- .has_subset(bterms)
out <- has_ad_terms(bterms, "subset")
} else if (is.mvbrmsterms(bterms)) {
out <- any(ulapply(bterms$terms, .has_subset))
out <- any(ulapply(bterms$terms, has_ad_terms, "subset"))
} else {
out <- FALSE
}
out
}

# check if a model has certain addition terms
has_ad_terms <- function(bterms, terms) {
stopifnot(is.brmsterms(bterms), is.character(terms))
any(ulapply(bterms$adforms[terms], is.formula))
}

# construct a list of indices for cross-formula referencing
frame_index <- function(x, data) {
out <- .frame_index(x, data)
Expand Down
Loading

0 comments on commit d728ae3

Please sign in to comment.