Skip to content

Commit

Permalink
Important bug fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
NicChr committed Nov 20, 2024
1 parent f997146 commit 97b731f
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 16 deletions.
42 changes: 26 additions & 16 deletions R/f_summarise.R
Original file line number Diff line number Diff line change
Expand Up @@ -220,32 +220,42 @@ f_summarise <- function(data, ..., .by = NULL,
fns <- rlang::as_label(across_fns)
fn_names <- fns
}
fn_matches <- cheapr::na_rm(sort(c(match(fns, base_fns),
match(fns, collapse_fns))))
which_fns <- which(fns %in% base_fns | fns %in% collapse_fns)
which_other_fns <- which(fns %in% base_fns | fns %in% collapse_fns,
invert = TRUE)
fast_fn_names <- collapse_fns[fn_matches]
base_matches <- match(fns, base_fns)
collapse_matches <- match(fns, collapse_fns)
fast_fns <- rep_len(NA_character_, length(fns))
for (i in seq_along(fns)){
if (!is.na(base_matches[i])){
fast_fns[i] <- collapse_fns[base_matches[i]]
} else if (!is.na(collapse_matches[i])){
fast_fns[i] <- collapse_fns[collapse_matches[i]]
}
}
which_fast <- cheapr::na_find(fast_fns, invert = TRUE)
which_other <- cheapr::na_find(fast_fns)
fast_fns <- cheapr::na_rm(fast_fns)

full_res <- vector("list", length(vars) * length(fns))
col_matrix <- matrix(logical( length(vars) * length(fns)),
nrow = length(vars),
ncol = length(fns))
col_matrix[, which_fns] <- TRUE
across_res <- fast_eval_across(temp_data, groups, vars, fast_fn_names, dot_env)
nrow = length(fns),
ncol = length(vars))
for (col in seq_along(vars)){
col_matrix[which_fast, col] <- TRUE
}
across_res <- fast_eval_across(temp_data, groups, vars, fast_fns, dot_env)

if (length(which_other_fns) > 0){
if (length(which_other) > 0){
if (!grouped_df_has_been_constructed){
temp_data <- construct_grouped_df(temp_data, groups, group_vars)
}
grouped_df_has_been_constructed <- TRUE
if (across_fns_as_list){
dplyr_res <- dplyr::summarise(
temp_data, dplyr::across(
dplyr::all_of(vars),
rlang::eval_tidy(across_fns, env = dot_env)[which_other_fns]
rlang::eval_tidy(across_fns, env = dot_env)[which_other]
), .groups = "drop"
)
} else {
if (!grouped_df_has_been_constructed){
temp_data <- construct_grouped_df(temp_data, groups, group_vars)
}
grouped_df_has_been_constructed <- TRUE
dplyr_res <- dplyr::summarise(
temp_data, dplyr::across(
dplyr::all_of(vars),
Expand Down
51 changes: 51 additions & 0 deletions tests/testthat/test-f_summarise.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,55 @@ test_that("summarise", {
),
target
)

# 2 variables and a mix of optimised/non-optimised calls

target <- airquality %>%
dplyr::summarise(
dplyr::across(
dplyr::all_of(c("Wind", "Temp")),
list(mean = function(x) mean(x, na.rm = TRUE),
first = function(x) x[1],
min = function(x) min(x, na.rm = TRUE),
last_obs = function(x) x[length(x)],
max = function(x) max(x, na.rm = TRUE))
), N = dplyr::n()
)

expect_equal(
airquality %>%
f_summarise(
dplyr::across(dplyr::all_of(c("Wind", "Temp")),
list(mean, first = function(x) x[1], min, last_obs = function(x) x[length(x)], max)),
N = dplyr::n()
),
target
)

# 2 variables and a mix of optimised/non-optimised calls, and groups

target <- airquality %>%
dplyr::summarise(
dplyr::across(
dplyr::all_of(c("Wind", "Temp")),
list(mean = function(x) mean(x, na.rm = TRUE),
first = function(x) x[1],
min = function(x) min(x, na.rm = TRUE),
last_obs = function(x) x[length(x)],
max = function(x) max(x, na.rm = TRUE))
), N = dplyr::n(),
.by = Month
)

expect_equal(
airquality %>%
f_summarise(
dplyr::across(dplyr::all_of(c("Wind", "Temp")),
list(mean, first = function(x) x[1], min, last_obs = function(x) x[length(x)], max)),
N = dplyr::n(),
.by = Month,
.order = FALSE
),
target
)
})

0 comments on commit 97b731f

Please sign in to comment.