Skip to content

Commit

Permalink
updated tests: added case for the complete last iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
bhimrazy committed Sep 9, 2024
1 parent d40e3ca commit be77b58
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions tests/streaming/test_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,9 +952,20 @@ def test_combined_dataset_dataloader_states(tmpdir):
dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=4)
assert len(dataloader) == 25, "Dataloader length should be 25 (50+50 items / batch size 4)"

# Verify dataloader state after partial iteration
# Verify dataloader state after complete last iteration
for batch in dataloader:
assert dataloader.current_epoch == 1, "Current epoch should be 1"
pass
dataloader.load_state_dict(dataloader.state_dict())
assert not dataloader.restore
for batch in dataloader:
assert dataloader.current_epoch == 2, "Current epoch should be 2"
pass

# Verify dataloader state after partial last iteration
dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=4)
for batch_idx, batch in enumerate(dataloader):
# assert dataloader.current_epoch == 1, "Current epoch should be 1"
assert dataloader.current_epoch == 1, "Current epoch should be 1"
if batch_idx == 10:
break

Expand Down

0 comments on commit be77b58

Please sign in to comment.