Skip to content

Commit

Permalink
Add print method for sdmTMB_cv() #319
Browse files Browse the repository at this point in the history
  • Loading branch information
seananderson committed Sep 23, 2024
1 parent 796eb08 commit ed977c7
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 2 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: sdmTMB
Title: Spatial and Spatiotemporal SPDE-Based GLMMs with 'TMB'
Version: 0.6.0.9009
Version: 0.6.0.9010
Authors@R: c(
person(c("Sean", "C."), "Anderson", , "[email protected]",
role = c("aut", "cre"),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ S3method(nobs,sdmTMB)
S3method(plot,sdmTMBmesh)
S3method(predict,sdmTMB)
S3method(print,sdmTMB)
S3method(print,sdmTMB_cv)
S3method(ranef,sdmTMB)
S3method(residuals,sdmTMB)
S3method(simulate,sdmTMB)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# sdmTMB (development version)

* Add print method for `sdmTMB_cv()` output. #319

* Add progress bar to `simulate.sdmTMB()`. #346

* Add AUC and TSS examples to cross validation vignette. #268
Expand Down
28 changes: 27 additions & 1 deletion R/cross-val.R
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ sdmTMB_cv <- function(
pdHess <- vapply(out, `[[`, "pdHess", FUN.VALUE = logical(1L))
max_grad <- vapply(out, `[[`, "max_gradient", FUN.VALUE = numeric(1L))
converged <- all(pdHess)
list(
out <- list(
data = data,
models = models,
fold_loglik = fold_cv_ll,
Expand All @@ -462,9 +462,35 @@ sdmTMB_cv <- function(
pdHess = pdHess,
max_gradients = max_grad
)
`class<-`(out, "sdmTMB_cv")
}

log_sum_exp <- function(x) {
max_x <- max(x)
max_x + log(sum(exp(x - max_x)))
}

#' @export
#' @import methods
print.sdmTMB_cv <- function(x, ...) {
nmods <- length(x$models)
nconverged <- sum(x$converged)
cat(paste0("Cross validation of sdmTMB models with ", nmods, " folds.\n"))
cat("\n")
cat("Summary of the first fold model fit:\n")
cat("\n")
print(x$models[[1]])
cat("\n")
cat("Access the rest of the models in a list element named `models`.\n")
cat("E.g. `object$models[[2]]` for the 2nd fold model fit.\n")
cat("\n")
cat(paste0(nconverged, " out of ", nmods, " models are consistent with convergence.\n"))
cat("Figure out which folds these are in the `converged` list element.\n")
cat("\n")
cat(paste0("Out-of-sample log likelihood for each fold: ", paste(round(x$fold_loglik, 2), collapse = ", "), ".\n"))
cat("Access these values in the `fold_loglik` list element.\n")
cat("\n")
cat("Sum of out-of-sample log likelihoods:", round(x$sum_loglik, 2), "\n")
cat("More positive values imply better out-of-sample prediction.\n")
cat("Access this value in the `sum_loglik` list element.\n")
}
1 change: 1 addition & 0 deletions tests/testthat/test-cross-validation.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ test_that("Basic cross validation works", {
data = d, mesh = spde,
family = tweedie(link = "log"), time = "year", k_folds = 2
)
print(x)
expect_equal(class(x$sum_loglik), "numeric")
expect_equal(x$sum_loglik, sum(x$data$cv_loglik))
expect_equal(x$sum_loglik, sum(x$fold_loglik))
Expand Down

0 comments on commit ed977c7

Please sign in to comment.