Skip to content

Commit

Permalink
Merge pull request #1211 from tidymodels/misc-sparsevctrs
Browse files: browse the repository at this point in the history
Misc sparsevctrs
  • Loading branch information
EmilHvitfeldt authored Oct 4, 2024
2 parents ee072ce + 3bf37b5 commit 5ce414e
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 12 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Imports:
prettyunits,
purrr (>= 1.0.0),
rlang (>= 1.1.0),
sparsevctrs (>= 0.1.0.9000),
sparsevctrs (>= 0.1.0.9002),
stats,
tibble (>= 2.1.1),
tidyr (>= 1.3.0),
Expand Down
4 changes: 2 additions & 2 deletions R/convert_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
)
}

if (is_sparse_tibble(data)) {
if (sparsevctrs::has_sparse_elements(data)) {
cli::cli_abort(
"Sparse data cannot be used with formula interface. Please use
{.fn fit_xy} instead."
Expand Down Expand Up @@ -417,7 +417,7 @@ maybe_sparse_matrix <- function(x) {
return(x)
}

if (is_sparse_tibble(x)) {
if (sparsevctrs::has_sparse_elements(x)) {
res <- sparsevctrs::coerce_to_sparse_matrix(x)
} else {
res <- as.matrix(x)
Expand Down
2 changes: 1 addition & 1 deletion R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ fit.model_spec <-


if (is_sparse_matrix(data)) {
data <- sparsevctrs::coerce_to_sparse_tibble(data)
data <- sparsevctrs::coerce_to_sparse_tibble(data, rlang::caller_env(0))
}

dots <- quos(...)
Expand Down
2 changes: 1 addition & 1 deletion R/fit_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ form_xy <- function(object, control, env,
remove_intercept <- encoding_info %>% dplyr::pull(remove_intercept)
allow_sparse_x <- encoding_info %>% dplyr::pull(allow_sparse_x)

if (allow_sparse_x && is_sparse_tibble(env$data)) {
if (allow_sparse_x && sparsevctrs::has_sparse_elements(env$data)) {
target <- "dgCMatrix"
}

Expand Down
2 changes: 1 addition & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ prepare_data <- function(object, new_data) {
if (allow_sparse(object) && inherits(new_data, "dgCMatrix")) {
return(new_data)
}
if (allow_sparse(object) && is_sparse_tibble(new_data)) {
if (allow_sparse(object) && sparsevctrs::has_sparse_elements(new_data)) {
new_data <- sparsevctrs::coerce_to_sparse_matrix(new_data)
return(new_data)
}
Expand Down
6 changes: 1 addition & 5 deletions R/sparsevctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,12 @@ to_sparse_data_frame <- function(x, object, call = rlang::caller_env()) {
x
}

# Does `x` contain at least one sparse vector?
#
# Applies `sparsevctrs::is_sparse_vector()` to every element of `x` (every
# column, when `x` is a data frame) and returns a single logical: `TRUE` if any
# element is flagged as sparse, `FALSE` otherwise (including when `x` is empty).
is_sparse_tibble <- function(x) {
  sparse_flags <- vapply(x, sparsevctrs::is_sparse_vector, logical(1))
  any(sparse_flags)
}

# Does `x` inherit from the S4 class "sparseMatrix"?
#
# Returns a single logical as reported by `methods::is()`.
is_sparse_matrix <- function(x) {
  sparse_class <- "sparseMatrix"
  methods::is(object = x, class2 = sparse_class)
}

materialize_sparse_tibble <- function(x, object, input) {
if (is_sparse_tibble(x) && (!allow_sparse(object))) {
if (sparsevctrs::has_sparse_elements(x) && (!allow_sparse(object))) {
if (inherits(object, "model_fit")) {
object <- object$spec
}
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/_snaps/sparsevctrs.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@
Warning:
`data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.

---

Code
fit(spec, avg_price_per_room ~ ., data = hotel_data)
Condition
Error in `fit()`:
! `x` must have column names.

# sparse tibble can be passed to `fit_xy() - unsupported

Code
Expand Down
16 changes: 15 additions & 1 deletion tests/testthat/test-sparsevctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ test_that("maybe_sparse_matrix() is used correctly", {

local_mocked_bindings(
maybe_sparse_matrix = function(x) {
if (is_sparse_tibble(x)) {
if (sparsevctrs::has_sparse_elements(x)) {
stop("sparse vectors detected")
} else {
stop("no sparse vectors detected")
Expand Down Expand Up @@ -313,3 +313,17 @@ test_that("maybe_sparse_matrix() is used correctly", {
fit_xy(spec, x = tibble::as_tibble(mtcars)[, -1], y = tibble::as_tibble(mtcars)[, 1])
)
})

test_that("fit() errors if sparse matrix has no colnames", {
  # Model specification: boosted trees, regression mode, xgboost engine.
  spec <- set_engine(set_mode(boost_tree(), "regression"), "xgboost")

  # Strip the column names from the example data so that fit() has nothing to
  # name the predictors with.
  # NOTE(review): presumably sparse_hotel_rates() returns a sparse matrix, as
  # the test title suggests -- confirm against the helper's definition.
  hotel_data <- sparse_hotel_rates()
  colnames(hotel_data) <- NULL

  # The call below is recorded verbatim in the snapshot file, so its text and
  # the names `spec` / `hotel_data` must stay exactly as written.
  expect_snapshot(
    fit(spec, avg_price_per_room ~ ., data = hotel_data),
    error = TRUE
  )
})

0 comments on commit 5ce414e

Please sign in to comment.