From b08b14bc7b14a4e4783d2c137bc6e5db05f3cbd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A1bor=20Cs=C3=A1rdi?= Date: Mon, 4 Nov 2024 12:31:36 +0100 Subject: [PATCH] Scan deps from parsnip::set_engine() --- R/scan-deps-queries.R | 8 +++++--- R/scan-deps.R | 43 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/R/scan-deps-queries.R b/R/scan-deps-queries.R index 4fc54f46..03cfec49 100644 --- a/R/scan-deps-queries.R +++ b/R/scan-deps-queries.R @@ -11,7 +11,8 @@ q_library_0 <- function() { "module" "tar_option_set" "glue" - "ggsave"))', + "ggsave" + "set_engine"))', '((call function: (namespace_operator lhs: (identifier) @ns-name @@ -20,7 +21,7 @@ q_library_0 <- function() { ) @dep-code (#any-of? @ns-name "base" "xfun" "pacman" "modules" "import" "box" "targets" "glue" - "ggplot2") + "ggplot2" "parsnip") (#any-of? @fn-name "library" "require" "loadNamespace" "requireNamespace" "pkg_attach" "pkg_attach2" @@ -30,7 +31,8 @@ q_library_0 <- function() { "use" "tar_option_set" "glue" - "ggsave"))' + "ggsave" + "set_engine"))' ), names = rep("q_library_0", 2)) } diff --git a/R/scan-deps.R b/R/scan-deps.R index 6d04e6c7..459c932c 100644 --- a/R/scan-deps.R +++ b/R/scan-deps.R @@ -44,7 +44,8 @@ re_r_dep <- paste0(collapse = "|", c( "box", "tar_option_set", "glue", - "ggsave" + "ggsave", + "set_engine" )) scan_path_deps <- function(path) { @@ -161,7 +162,7 @@ scan_path_deps_do_gen_hits <- function(hits, path) { } ) pkgs <- lapply(seq_along(code), function(i) { - parse_pkg_from_call(ns[i], fn[i], code[i]) + safe_parse_pkg_from_call(ns[i], fn[i], code[i]) }) pkgs_count <- lengths(pkgs) data_frame( @@ -193,6 +194,7 @@ prot_glue_glue <- function( ..., .sep = "", .envir = parent.frame(), .open = "{", .close = "}") { } prot_ggplot2_ggsave <- function(filename, ...) { } +prot_parsnip_set_engine <- function(object, engine, ...) { } safe_parse_pkg_from_call <- function(ns, fn, code) { tryCatch( @@ -219,7 +221,8 @@ parse_pkg_from_call <- function(ns, fn, code) { "use" = prot_box_use, "tar_option_set" = prot_targets_tar_option_set, "glue" = prot_glue_glue, - "ggsave" = prot_ggplot2_ggsave + "ggsave" = prot_ggplot2_ggsave, + "set_engine" = prot_parsnip_set_engine ) matched <- match.call(fun, expr, expand.dots = FALSE) switch(fn, @@ -244,7 +247,9 @@ parse_pkg_from_call <- function(ns, fn, code) { "glue" = parse_pkg_from_call_glue(ns, fn, matched), "ggsave" = - parse_pkg_from_call_ggplot2(ns, fn, matched) + parse_pkg_from_call_ggplot2(ns, fn, matched), + "set_engine" = + parse_pkg_from_call_parsnip(ns, fn, matched) ) } @@ -426,6 +431,36 @@ parse_pkg_from_call_ggplot2 <- function(ns, fn, matched) { NULL } +parse_pkg_from_call_parsnip <- function(ns, fn, matched) { + if (!is.na(ns) && ns != "parsnip") return(NULL) + engine <- matched[["engine"]] + if (!is.character(engine) || length(engine) != 1L) { + return(NULL) + } + + map <- getOption("renv.parsnip.engines", default = list( + glm = "stats", + glmnet = "glmnet", + keras = "keras", + kknn = "kknn", + nnet = "nnet", + rpart = "rpart", + spark = "sparklyr", + stan = "rstanarm" + )) + + pkgs <- if (is.function(map)) { + map(engine) + } else { + map[[engine]] + } + + if (length(pkgs) > 0) { + return(pkgs) + } + NULL +} + # ------------------------------------------------------------------------- scan_path_deps_do_rmd <- function(code, path) {