Skip to content

Commit

Permalink
fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
deependujha committed Jul 12, 2024
1 parent 34bcc6f commit bbcaaf4
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tests/streaming/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,7 @@ def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch):
train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)
batches_to_fetch = 16
batch_to_resume_from = None
dataloader_state = None
for i, batch in enumerate(train_dataloader):
if i == batches_to_fetch:
dataloader_state = train_dataloader.state_dict()
Expand All @@ -840,6 +841,8 @@ def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch):
shutil.rmtree(s3_cache_dir)
os.mkdir(s3_cache_dir)
train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)
assert dataloader_state is not None
assert batch_to_resume_from is not None
train_dataloader.load_state_dict(dataloader_state)
# The next batch after resuming must match what we should have gotten next in the initial loop
assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from)
Expand Down

0 comments on commit bbcaaf4

Please sign in to comment.