diff --git a/DESCRIPTION b/DESCRIPTION index 56d1328a..a18063a1 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -68,3 +68,5 @@ Collate: 'module.R' 'reexports.R' 'serialization.R' +Remotes: + mlverse/torch diff --git a/NAMESPACE b/NAMESPACE index e0a52dfb..f4fdb8bb 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -4,6 +4,7 @@ S3method(as_dataloader,array) S3method(as_dataloader,dataloader) S3method(as_dataloader,dataset) S3method(as_dataloader,default) +S3method(as_dataloader,iterable_dataset) S3method(as_dataloader,list) S3method(as_dataloader,matrix) S3method(as_dataloader,numeric) diff --git a/R/as_dataloader.R b/R/as_dataloader.R index 60e5c76a..7c4ba9a3 100644 --- a/R/as_dataloader.R +++ b/R/as_dataloader.R @@ -53,6 +53,11 @@ as_dataloader.dataset <- function(x, ..., batch_size = 32) { torch::dataloader(dataset = x, batch_size = batch_size, ...) } +#' @inheritParams as_dataloader.dataset +#' @describeIn as_dataloader Converts a [torch::iterable_dataset()] into a [torch::dataloader()] +#' @export +as_dataloader.iterable_dataset <- as_dataloader.dataset + #' @describeIn as_dataloader Converts a list of tensors or arrays with the same #' size in the first dimension to a [torch::dataloader()] #' @export diff --git a/R/callbacks.R b/R/callbacks.R index b41ff5b8..73307207 100644 --- a/R/callbacks.R +++ b/R/callbacks.R @@ -227,12 +227,22 @@ luz_callback_progress <- luz_callback( } }, initialize_progress_bar = function(split) { - format <- ":current/:total [:bar]" + total <- length(ctx$data) # ctx$data is the current dataset - can be the validation or training. + + if (!is.na(total)) { + format <- ":current/:total [:bar]" + } else { + format <- ":current/unk [:spin]" + } # Specially for testing purposes we don't want to have the # progress bar showing the ETA. if (getOption("luz.show_progress_bar_eta", TRUE)) { - format <- paste0(format, " - ETA: :eta") + if (!is.na(format)) { + format <- paste0(format, " - ETA: :eta") + } else { + format <- paste0(format, " - Rate: :tick_rate iter/s") + } } metrics <- ctx$metrics[[split]] @@ -246,7 +256,6 @@ luz_callback_progress <- luz_callback( show_after <- if (getOption("luz.force_progress_bar", FALSE)) 0 else 0.2 format <- paste0(c(format, abbrevs), collapse = " - ") - total <- length(ctx$data) # ctx$data is the current dataset - can be the validation or training. self$pb <- progress::progress_bar$new( force = getOption("luz.force_progress_bar", FALSE), diff --git a/man/as_dataloader.Rd b/man/as_dataloader.Rd index 44fb2b26..428f4055 100644 --- a/man/as_dataloader.Rd +++ b/man/as_dataloader.Rd @@ -3,6 +3,7 @@ \name{as_dataloader} \alias{as_dataloader} \alias{as_dataloader.dataset} +\alias{as_dataloader.iterable_dataset} \alias{as_dataloader.list} \alias{as_dataloader.dataloader} \alias{as_dataloader.matrix} @@ -15,6 +16,8 @@ as_dataloader(x, ...) \method{as_dataloader}{dataset}(x, ..., batch_size = 32) +\method{as_dataloader}{iterable_dataset}(x, ..., batch_size = 32) + \method{as_dataloader}{list}(x, ...) \method{as_dataloader}{dataloader}(x, ...) @@ -52,6 +55,8 @@ experiments. \itemize{ \item \code{as_dataloader(dataset)}: Converts a \code{\link[torch:dataset]{torch::dataset()}} to a \code{\link[torch:dataloader]{torch::dataloader()}}. +\item \code{as_dataloader(iterable_dataset)}: Converts a \code{\link[torch:iterable_dataset]{torch::iterable_dataset()}} into a \code{\link[torch:dataloader]{torch::dataloader()}} + \item \code{as_dataloader(list)}: Converts a list of tensors or arrays with the same size in the first dimension to a \code{\link[torch:dataloader]{torch::dataloader()}} diff --git a/tests/testthat/_snaps/callbacks.md b/tests/testthat/_snaps/callbacks.md index 95f0d1d1..148c388f 100644 --- a/tests/testthat/_snaps/callbacks.md +++ b/tests/testthat/_snaps/callbacks.md @@ -440,3 +440,47 @@ x Callbacks must have class but got i Perhaps you forgot to initialize the callback? +# can get progress when using iterable datasets + + Code + expect_message({ + output <- mod %>% set_hparams(input_size = 10, output_size = 1) %>% fit( + get_iterable_ds(), verbose = TRUE, epochs = 2, valid_data = get_iterable_ds(), + ) + }) + Message + + 1/unk [-] - Loss: 1.776 + + 2/unk [\] - Loss: 1.6358 + + 3/unk [|] - Loss: 1.6954 + + 1/unk [-] - Loss: 1.664 + + 2/unk [\] - Loss: 1.4837 + + 3/unk [|] - Loss: 1.6957 + + 4/unk [/] - Loss: 1.3467 + Train metrics: Loss: 1.6954 + Valid metrics: Loss: 1.3467 + Epoch 2/2 + + 1/unk [-] - Loss: 1.763 + + 2/unk [\] - Loss: 1.6215 + + 3/unk [|] - Loss: 1.6819 + + 1/unk [-] - Loss: 1.659 + + 2/unk [\] - Loss: 1.4758 + + 3/unk [|] - Loss: 1.6868 + + + 4/unk [/] - Loss: 1.34 + Train metrics: Loss: 1.6819 + Valid metrics: Loss: 1.34 + diff --git a/tests/testthat/test-callbacks.R b/tests/testthat/test-callbacks.R index 52781adb..c2b01f7d 100644 --- a/tests/testthat/test-callbacks.R +++ b/tests/testthat/test-callbacks.R @@ -133,3 +133,67 @@ test_that("improve error message when you provide a unitinitilized callback", { }) }) + +test_that("can get progress when using iterable datasets", { + + torch::torch_manual_seed(1) + set.seed(1) + + model <- get_model() + + get_iterable_ds <- torch::iterable_dataset( + "iterable_ds", + initialize = function(len = 100, x_size = 10, y_size = 1, fixed_values = FALSE) { + self$len <- len + self$x <- torch::torch_randn(size = c(len, x_size)) + self$y <- torch::torch_randn(size = c(len, y_size)) + }, + .iter = function() { + i <- 0 + function() { + i <<- i + 1 + + if (i > self$len) { + return(coro::exhausted()) + } + + list( + x = self$x[i,..], + y = self$y[i,..] + ) + } + } + ) + + ds <- get_iterable_ds() + dl <- torch::dataloader(ds, batch_size = 32) + + mod <- model %>% + setup( + loss = torch::nn_mse_loss(), + optimizer = torch::optim_adam, + ) + + + withr::with_options(list( + luz.force_progress_bar = TRUE, + luz.show_progress_bar_eta = FALSE, + width = 80), { + + expect_snapshot({ + expect_message({ + output <- mod %>% + set_hparams(input_size = 10, output_size = 1) %>% + fit( + get_iterable_ds(), + verbose = TRUE, + epochs = 2, + valid_data = get_iterable_ds(), + ) + }) + }) + }) + + + +})