From dff88ca4353d3ee789478b0b71fd9b34ebfd554c Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 5 Sep 2024 10:55:41 +0545 Subject: [PATCH] Adds more tests --- tests/streaming/test_dataloader.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index d160c613..db57e64a 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -369,5 +369,18 @@ def test_combined_dataset_dataloader_states(tmpdir): assert dataloader.restore # Verify remaining batches in the first epoch + count = 0 + for _ in dataloader: + assert dataloader.current_epoch == 1, "Current epoch should be 1" + count += 1 + assert count == 15, "There should be atleast 15 batches remaining in the first epoch" + assert not dataloader.restore + + # Verify batches in the second epoch + count = 0 + for _ in dataloader: + assert dataloader.current_epoch == 2, "Current epoch should be 2" + count += 1 + assert count >= 25, "There should be at least 25 batches in the second epoch" # TODO: Add more conditions to check the state of the dataloader