diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index d160c613..db57e64a 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -369,5 +369,18 @@ def test_combined_dataset_dataloader_states(tmpdir): assert dataloader.restore # Verify remaining batches in the first epoch + count = 0 + for _ in dataloader: + assert dataloader.current_epoch == 1, "Current epoch should be 1" + count += 1 + assert count == 15, "There should be atleast 15 batches remaining in the first epoch" + assert not dataloader.restore + + # Verify batches in the second epoch + count = 0 + for _ in dataloader: + assert dataloader.current_epoch == 2, "Current epoch should be 2" + count += 1 + assert count >= 25, "There should be at least 25 batches in the second epoch" # TODO: Add more conditions to check the state of the dataloader