Skip to content

Commit

Permalink
match documentation other plot_* functions
Browse files Browse the repository at this point in the history
  • Loading branch information
hanneoberman committed Jul 26, 2024
1 parent dd4b9d2 commit 734d3d0
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions R/plot_trace.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@
#' plot_trace(imp)
#'
#' # plot trace lines for specific columns by supplying a string or character vector
#' plot_trace(imp, "bmi")
#' plot_trace(imp, c("bmi", "hyp"))
#' plot_trace(imp, "chl")
#' plot_trace(imp, c("chl", "hyp"))

#' # plot trace lines for specific columns by supplying unquoted variable names
#' plot_trace(imp, bmi)
#' plot_trace(imp, c(bmi, hyp))
#' plot_trace(imp, chl)
#' plot_trace(imp, c(chl, hyp))
#'
#' # plot trace lines for specific columns by passing an object with variable names
#' # from the environment, unquoted with `!!`
#' my_variables <- c("bmi", "hyp")
#' my_variables <- c("chl", "hyp")
#' plot_trace(imp, !!my_variables)
#' # object with variable names must be unquoted with `!!`
#' try(plot_trace(imp, my_variables))
Expand All @@ -41,15 +41,12 @@ plot_trace <- function(data, vrb = "all") {
if (is.null(data$chainMean) && is.null(data$chainVar)) {
cli::cli_abort("No convergence diagnostics found", call. = FALSE)
}

# extract chain means and chain standard deviations
mn <- data$chainMean
sm <- sqrt(data$chainVar)

# select variable to plot from list of imputed variables
vrb <- rlang::enexpr(vrb)
vrbs_in_data <- names(data$imp)
vrb_matched <- match_vrb(vrb, vrbs_in_data)
# extract chain means and chain standard deviations
mn <- data$chainMean
sm <- sqrt(data$chainVar)
available_vrbs <- vrbs_in_data[apply(!(is.nan(mn) | is.na(sm)), 1, all)]
if (any(vrb_matched %nin% available_vrbs)) {
cli::cli_inform(
Expand All @@ -61,6 +58,7 @@ plot_trace <- function(data, vrb = "all") {
)
}
vrb_matched <- vrb_matched[which(vrb_matched %in% available_vrbs)]
# compute diagnostics
p <- length(vrb_matched)
m <- data$m
it <- data$iteration
Expand All @@ -77,8 +75,7 @@ plot_trace <- function(data, vrb = "all") {
)
)
)

# plot the convergence diagnostics
# create plot
ggplot2::ggplot(long,
ggplot2::aes(
x = .data$.it,
Expand Down

0 comments on commit 734d3d0

Please sign in to comment.