Skip to content

Commit

Permalink
Refactor test_combined.py to fix restore state issue
Browse files Browse the repository at this point in the history
  • Loading branch information
bhimrazy committed Sep 9, 2024
1 parent be77b58 commit 86ccd99
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tests/streaming/test_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ def test_combined_dataset_dataloader_states(tmpdir):
dataloader = StreamingDataLoader(combined_dataset, batch_size=4)
assert not dataloader.restore
dataloader.load_state_dict(dataloader.state_dict())
assert dataloader.restore
assert not dataloader.restore
batch = next(iter(dataloader))
assert len(batch) == 4, "Batch size should be 4"
assert len(dataloader) == 25, "Dataloader length should be 25 (50+50 items / batch size 4)"
Expand All @@ -968,8 +968,9 @@ def test_combined_dataset_dataloader_states(tmpdir):
assert dataloader.current_epoch == 1, "Current epoch should be 1"
if batch_idx == 10:
break

dataloader.load_state_dict(dataloader.state_dict())
state_dict = dataloader.state_dict()
print(state_dict)
dataloader.load_state_dict(state_dict)
assert dataloader.restore

# Verify remaining batches in the first epoch
Expand Down

0 comments on commit 86ccd99

Please sign in to comment.