From bbcaaf451a8c526ef72b7734e7bed889109d151a Mon Sep 17 00:00:00 2001 From: deependujha Date: Fri, 12 Jul 2024 13:12:15 +0530 Subject: [PATCH] fix failing tests --- tests/streaming/test_dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index da78b02c..b1615198 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -830,6 +830,7 @@ def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch): train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir) batches_to_fetch = 16 batch_to_resume_from = None + dataloader_state = None for i, batch in enumerate(train_dataloader): if i == batches_to_fetch: dataloader_state = train_dataloader.state_dict() @@ -840,6 +841,8 @@ def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch): shutil.rmtree(s3_cache_dir) os.mkdir(s3_cache_dir) train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir) + assert dataloader_state is not None + assert batch_to_resume_from is not None train_dataloader.load_state_dict(dataloader_state) # The next batch after resuming must match what we should have gotten next in the initial loop assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from)