Skip to content

Commit

Permalink
Add support for iterable datasets. (#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfalbel authored Aug 29, 2023
1 parent e98c62f commit 6921ba8
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 3 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,5 @@ Collate:
'module.R'
'reexports.R'
'serialization.R'
Remotes:
mlverse/torch
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ S3method(as_dataloader,array)
S3method(as_dataloader,dataloader)
S3method(as_dataloader,dataset)
S3method(as_dataloader,default)
S3method(as_dataloader,iterable_dataset)
S3method(as_dataloader,list)
S3method(as_dataloader,matrix)
S3method(as_dataloader,numeric)
Expand Down
5 changes: 5 additions & 0 deletions R/as_dataloader.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ as_dataloader.dataset <- function(x, ..., batch_size = 32) {
torch::dataloader(dataset = x, batch_size = batch_size, ...)
}

#' @inheritParams as_dataloader.dataset
#' @describeIn as_dataloader Converts a [torch::iterable_dataset()] into a [torch::dataloader()]
#' @export
as_dataloader.iterable_dataset <- as_dataloader.dataset

#' @describeIn as_dataloader Converts a list of tensors or arrays with the same
#' size in the first dimension to a [torch::dataloader()]
#' @export
Expand Down
15 changes: 12 additions & 3 deletions R/callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,22 @@ luz_callback_progress <- luz_callback(
}
},
initialize_progress_bar = function(split) {
format <- ":current/:total [:bar]"
total <- length(ctx$data) # ctx$data is the current dataset - can be the validation or training.

if (!is.na(total)) {
format <- ":current/:total [:bar]"
} else {
format <- ":current/unk [:spin]"
}

# Specially for testing purposes we don't want to have the
# progress bar showing the ETA.
if (getOption("luz.show_progress_bar_eta", TRUE)) {
format <- paste0(format, " - ETA: :eta")
if (!is.na(format)) {
format <- paste0(format, " - ETA: :eta")
} else {
format <- paste0(format, " - Rate: :tick_rate iter/s")
}
}

metrics <- ctx$metrics[[split]]
Expand All @@ -246,7 +256,6 @@ luz_callback_progress <- luz_callback(
show_after <- if (getOption("luz.force_progress_bar", FALSE)) 0 else 0.2

format <- paste0(c(format, abbrevs), collapse = " - ")
total <- length(ctx$data) # ctx$data is the current dataset - can be the validation or training.

self$pb <- progress::progress_bar$new(
force = getOption("luz.force_progress_bar", FALSE),
Expand Down
5 changes: 5 additions & 0 deletions man/as_dataloader.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 44 additions & 0 deletions tests/testthat/_snaps/callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,47 @@
x Callbacks must have class <LuzCallback> but got <function>
i Perhaps you forgot to initialize the callback?

# can get progress when using iterable datasets

Code
expect_message({
output <- mod %>% set_hparams(input_size = 10, output_size = 1) %>% fit(
get_iterable_ds(), verbose = TRUE, epochs = 2, valid_data = get_iterable_ds(),
)
})
Message
1/unk [-] - Loss: 1.776
2/unk [\] - Loss: 1.6358
3/unk [|] - Loss: 1.6954
1/unk [-] - Loss: 1.664
2/unk [\] - Loss: 1.4837
3/unk [|] - Loss: 1.6957
4/unk [/] - Loss: 1.3467
Train metrics: Loss: 1.6954
Valid metrics: Loss: 1.3467
Epoch 2/2
1/unk [-] - Loss: 1.763
2/unk [\] - Loss: 1.6215
3/unk [|] - Loss: 1.6819
1/unk [-] - Loss: 1.659
2/unk [\] - Loss: 1.4758
3/unk [|] - Loss: 1.6868
4/unk [/] - Loss: 1.34
Train metrics: Loss: 1.6819
Valid metrics: Loss: 1.34

64 changes: 64 additions & 0 deletions tests/testthat/test-callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,67 @@ test_that("improve error message when you provide a unitinitilized callback", {
})

})

test_that("can get progress when using iterable datasets", {

torch::torch_manual_seed(1)
set.seed(1)

model <- get_model()

get_iterable_ds <- torch::iterable_dataset(
"iterable_ds",
initialize = function(len = 100, x_size = 10, y_size = 1, fixed_values = FALSE) {
self$len <- len
self$x <- torch::torch_randn(size = c(len, x_size))
self$y <- torch::torch_randn(size = c(len, y_size))
},
.iter = function() {
i <- 0
function() {
i <<- i + 1

if (i > self$len) {
return(coro::exhausted())
}

list(
x = self$x[i,..],
y = self$y[i,..]
)
}
}
)

ds <- get_iterable_ds()
dl <- torch::dataloader(ds, batch_size = 32)

mod <- model %>%
setup(
loss = torch::nn_mse_loss(),
optimizer = torch::optim_adam,
)


withr::with_options(list(
luz.force_progress_bar = TRUE,
luz.show_progress_bar_eta = FALSE,
width = 80), {

expect_snapshot({
expect_message({
output <- mod %>%
set_hparams(input_size = 10, output_size = 1) %>%
fit(
get_iterable_ds(),
verbose = TRUE,
epochs = 2,
valid_data = get_iterable_ds(),
)
})
})
})



})

0 comments on commit 6921ba8

Please sign in to comment.