Skip to content

Commit

Permalink
Lr scheduler state (#137)
Browse files Browse the repository at this point in the history
* Regain lr scheduler state dict.

* Correctly recover the learning rate scheduler.
  • Loading branch information
dfalbel authored Sep 12, 2023
1 parent f7e160f commit 6a8de1a
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 0 deletions.
6 changes: 6 additions & 0 deletions R/callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,12 @@ luz_callback_lr_scheduler <- luz_callback(
rlang::abort(glue::glue("opt_name '{self$opt_name}' not found in ctx$optimizers."))

self$scheduler <- self$lr_scheduler_fn(ctx$optimizers[[self$opt_name]])
},
state_dict = function() {
self$scheduler$state_dict()
},
load_state_dict = function(state_dict) {
self$scheduler$load_state_dict(state_dict)
}
)

Expand Down
19 changes: 19 additions & 0 deletions tests/testthat/_snaps/callbacks-resume.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# resuming a model with a lr scheduler callback is correct

Code
result <- model %>% fit(list(x, y), callbacks = list(autoresume,
luz_callback_lr_scheduler(lr_step, step_size = 1L),
luz_callback_simulate_failure(at_epoch = 11L), luz_callback_lr_progress()),
verbose = FALSE)
Message
lr=1e-06
lr=1e-06
lr=1e-06
lr=1e-06
lr=1e-06
lr=1e-07
lr=1e-08
lr=1e-09
lr=1e-10
lr=1e-11

55 changes: 55 additions & 0 deletions tests/testthat/test-callbacks-resume.R
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,58 @@ test_that("can use the resume_from callback", {
tr2$weights[[1]]
)
})

test_that("resuming a model with a lr scheduler callback is correct", {

x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)

model <- nn_linear %>%
setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
set_hparams(in_features = 10, out_features = 1) %>%
set_opt_hparams(lr = 0.01)

luz_callback_lr_progress <- luz_callback(
on_epoch_begin = function() {
rlang::inform(glue::glue("lr={ctx$opt$param_groups[[1]]$lr}"))
}
)

luz_callback_simulate_failure <- luz_callback(
initialize = function(at_epoch) {
self$at_epoch = at_epoch
},
on_epoch_begin = function() {
if (ctx$epoch>=self$at_epoch) rlang::abort("simulated failure")
}
)

autoresume <- luz_callback_auto_resume(path = tempfile())

expect_error(regexp = "simulated failure", {
result <- model %>% fit(
list(x, y),
callbacks = list(
autoresume,
luz_callback_lr_scheduler(lr_step,step_size=1L),
luz_callback_simulate_failure(at_epoch=5L),
luz_callback_lr_progress()
),
verbose = FALSE
)
})

expect_snapshot({
result <- model %>% fit(
list(x, y),
callbacks = list(
autoresume,
luz_callback_lr_scheduler(lr_step,step_size=1L),
luz_callback_simulate_failure(at_epoch=11L),
luz_callback_lr_progress()
),
verbose = FALSE
)
})

})

0 comments on commit 6a8de1a

Please sign in to comment.