Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lprior argument to prior, enabling separate lprior variables #1724

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions R/priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,8 @@
#' @export
set_prior <- function(prior, class = "b", coef = "", group = "",
resp = "", dpar = "", nlpar = "",
lb = NA, ub = NA, check = TRUE) {
input <- nlist(prior, class, coef, group, resp, dpar, nlpar, lb, ub, check)
lb = NA, ub = NA, lprior = "", check = TRUE) {
input <- nlist(prior, class, coef, group, resp, dpar, nlpar, lb, ub, lprior, check)
input <- try(as.data.frame(input), silent = TRUE)
if (is_try_error(input)) {
stop2("Processing arguments of 'set_prior' has failed:\n", input)
Expand All @@ -375,7 +375,7 @@ set_prior <- function(prior, class = "b", coef = "", group = "",

# validate arguments passed to 'set_prior'
.set_prior <- function(prior, class, coef, group, resp,
dpar, nlpar, lb, ub, check) {
dpar, nlpar, lb, ub, lprior, check) {
prior <- as_one_character(prior)
class <- as_one_character(class)
group <- as_one_character(group)
Expand All @@ -386,16 +386,17 @@ set_prior <- function(prior, class = "b", coef = "", group = "",
check <- as_one_logical(check)
lb <- as_one_character(lb, allow_na = TRUE)
ub <- as_one_character(ub, allow_na = TRUE)
lprior <- as_one_character(lprior)
if (dpar == "mu") {
# distributional parameter 'mu' is currently implicit #1368
dpar <- ""
}
if (!check) {
# prior will be added to the log-posterior as is
class <- coef <- group <- resp <- dpar <- nlpar <- lb <- ub <- ""
class <- coef <- group <- resp <- dpar <- nlpar <- lb <- ub <- lprior <- ""
}
source <- "user"
out <- nlist(prior, source, class, coef, group, resp, dpar, nlpar, lb, ub)
out <- nlist(prior, source, class, coef, group, resp, dpar, nlpar, lb, ub, lprior)
do_call(brmsprior, out)
}

Expand Down Expand Up @@ -558,7 +559,7 @@ default_prior.default <- function(object, data, family = gaussian(), autocor = N
# explicitly label default priors as such
prior$source <- "default"
# apply 'unique' as the same prior may have been included multiple times
to_order <- with(prior, order(resp, dpar, nlpar, class, group, coef))
to_order <- with(prior, order(resp, dpar, nlpar, class, group, coef, lprior))
prior <- unique(prior[to_order, , drop = FALSE])
rownames(prior) <- NULL
class(prior) <- c("brmsprior", "data.frame")
Expand Down Expand Up @@ -1565,7 +1566,7 @@ get_sample_prior <- function(prior) {
# create data.frames containing prior information
brmsprior <- function(prior = "", class = "", coef = "", group = "",
resp = "", dpar = "", nlpar = "", lb = "", ub = "",
source = "", ls = list()) {
lprior = "", source = "", ls = list()) {
if (length(ls)) {
if (is.null(names(ls))) {
stop("Argument 'ls' must be named.")
Expand All @@ -1580,7 +1581,7 @@ brmsprior <- function(prior = "", class = "", coef = "", group = "",
}
out <- data.frame(
prior, class, coef, group,
resp, dpar, nlpar, lb, ub, source,
resp, dpar, nlpar, lb, ub, lprior, source,
stringsAsFactors = FALSE
)
class(out) <- c("brmsprior", "data.frame")
Expand All @@ -1594,7 +1595,7 @@ empty_prior <- function() {
brmsprior(
prior = char0, source = char0, class = char0,
coef = char0, group = char0, resp = char0,
dpar = char0, nlpar = char0, lb = char0, ub = char0
dpar = char0, nlpar = char0, lb = char0, ub = char0, lprior = char0
)
}

Expand Down Expand Up @@ -1623,7 +1624,7 @@ prior_bounds <- function(prior) {
# all columns of brmsprior objects
all_cols_prior <- function() {
c("prior", "class", "coef", "group", "resp",
"dpar", "nlpar", "lb", "ub", "source")
"dpar", "nlpar", "lb", "ub", "lprior", "source")
}

# relevant columns for duplication checks in brmsprior objects
Expand Down Expand Up @@ -1915,7 +1916,7 @@ as.brmsprior <- function(x) {

defaults <- c(
class = "b", coef = "", group = "", resp = "",
dpar = "", nlpar = "", lb = NA, ub = NA
dpar = "", nlpar = "", lb = NA, ub = NA, lprior = ""
)
for (v in names(defaults)) {
if (!v %in% names(x)) {
Expand Down
17 changes: 14 additions & 3 deletions R/stan-prior.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL,
c(index) <- j
}
prior_ij <- subset2(prior, coef = coef[i, j])
lprior_tag <- prior_ij$lprior
if (NROW(px) > 1L) {
# disambiguate priors of coefficients with the same name
# coming from different model components
Expand Down Expand Up @@ -131,7 +132,13 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL,
coef_prior, par_ij, broadcast = broadcast,
bound = bound, resp = px$resp[1], normalize = normalize
)
# add to the lprior
str_add(out$tpar_prior) <- paste0(lpp(), coef_prior, ";\n")
# add to the lprior of the tag if specified
if (!is.null(lprior_tag) && lprior_tag != "") {
str_add(out$tpar_prior) <- paste0(lpp(tag = lprior_tag), coef_prior, ";\n")
}

}
}
}
Expand Down Expand Up @@ -241,7 +248,7 @@ stan_base_prior <- function(prior, col = "prior", sel_prior = NULL, ...) {
return(brmsprior()[, col])
}
}
vars <- c("group", "nlpar", "dpar", "resp", "class")
vars <- c("group", "nlpar", "dpar", "resp", "class", "lprior")
for (v in vars) {
take <- nzchar(prior[[v]])
if (any(take)) {
Expand Down Expand Up @@ -698,7 +705,11 @@ stopif_prior_bound <- function(prior, class, ...) {
}

# lprior plus equal
lpp <- function(wsp = 2) {
lpp <- function(wsp = 2, tag = NULL) {
wsp <- collapse(rep(" ", wsp))
paste0(wsp, "lprior += ")
if (is.null(tag)) {
paste0(wsp, "lprior", " += ")
} else {
paste0(wsp, "lprior_", tag, " += ")
}
}
5 changes: 5 additions & 0 deletions R/stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ stancode.default <- function(object, data, family = gaussian(),
backend = getOption("brms.backend", "rstan"),
silent = TRUE, save_model = NULL, ...) {

lprior_tags <- prior$lprior[prior$lprior != ""]

normalize <- as_one_logical(normalize)
parse <- as_one_logical(parse)
backend <- match.arg(backend, backend_choices())
Expand Down Expand Up @@ -278,12 +280,15 @@ stancode.default <- function(object, data, family = gaussian(),

# generate transformed parameters block
scode_lprior_def <- " real lprior = 0; // prior contributions to the log posterior\n"
scode_lprior_tags_def <- paste0(
" real lprior_", unique(lprior_tags), " = 0;\n", collapse = "")
scode_transformed_parameters <- paste0(
"transformed parameters {\n",
scode_predictor[["tpar_def"]],
scode_re[["tpar_def"]],
scode_Xme[["tpar_def"]],
str_if(normalize, scode_lprior_def),
str_if(normalize, scode_lprior_tags_def),
collapse_stanvars(stanvars, "tparameters", "start"),
scode_predictor[["tpar_prior_const"]],
scode_re[["tpar_prior_const"]],
Expand Down
Loading