diff --git a/R/plot_trace.R b/R/plot_trace.R index 21661b9..ecc6c57 100644 --- a/R/plot_trace.R +++ b/R/plot_trace.R @@ -21,16 +21,16 @@ #' plot_trace(imp) #' #' # plot trace lines for specific columns by supplying a string or character vector -#' plot_trace(imp, "bmi") -#' plot_trace(imp, c("bmi", "hyp")) +#' plot_trace(imp, "chl") +#' plot_trace(imp, c("chl", "hyp")) #' # plot trace lines for specific columns by supplying unquoted variable names -#' plot_trace(imp, bmi) -#' plot_trace(imp, c(bmi, hyp)) +#' plot_trace(imp, chl) +#' plot_trace(imp, c(chl, hyp)) #' #' # plot trace lines for specific columns by passing an object with variable names #' # from the environment, unquoted with `!!` -#' my_variables <- c("bmi", "hyp") +#' my_variables <- c("chl", "hyp") #' plot_trace(imp, !!my_variables) #' # object with variable names must be unquoted with `!!` #' try(plot_trace(imp, my_variables)) @@ -41,15 +41,12 @@ plot_trace <- function(data, vrb = "all") { if (is.null(data$chainMean) && is.null(data$chainVar)) { cli::cli_abort("No convergence diagnostics found", call. = FALSE) } - - # extract chain means and chain standard deviations - mn <- data$chainMean - sm <- sqrt(data$chainVar) - - # select variable to plot from list of imputed variables vrb <- rlang::enexpr(vrb) vrbs_in_data <- names(data$imp) vrb_matched <- match_vrb(vrb, vrbs_in_data) + # extract chain means and chain standard deviations + mn <- data$chainMean + sm <- sqrt(data$chainVar) available_vrbs <- vrbs_in_data[apply(!(is.nan(mn) | is.na(sm)), 1, all)] if (any(vrb_matched %nin% available_vrbs)) { cli::cli_inform( @@ -61,6 +58,7 @@ plot_trace <- function(data, vrb = "all") { ) } vrb_matched <- vrb_matched[which(vrb_matched %in% available_vrbs)] + # compute diagnostics p <- length(vrb_matched) m <- data$m it <- data$iteration @@ -77,8 +75,7 @@ plot_trace <- function(data, vrb = "all") { ) ) ) - - # plot the convergence diagnostics + # create plot ggplot2::ggplot(long, ggplot2::aes( x = .data$.it,