From 86ccd99ecda24b6237bdd1bb14dd57e2e745591f Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Mon, 9 Sep 2024 18:32:29 +0545 Subject: [PATCH] Refactor test_combined.py to fix restore state issue --- tests/streaming/test_combined.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/streaming/test_combined.py b/tests/streaming/test_combined.py index 5fee533a..98a3cfca 100644 --- a/tests/streaming/test_combined.py +++ b/tests/streaming/test_combined.py @@ -943,7 +943,7 @@ def test_combined_dataset_dataloader_states(tmpdir): dataloader = StreamingDataLoader(combined_dataset, batch_size=4) assert not dataloader.restore dataloader.load_state_dict(dataloader.state_dict()) - assert dataloader.restore + assert not dataloader.restore batch = next(iter(dataloader)) assert len(batch) == 4, "Batch size should be 4" assert len(dataloader) == 25, "Dataloader length should be 25 (50+50 items / batch size 4)" @@ -968,8 +968,9 @@ def test_combined_dataset_dataloader_states(tmpdir): assert dataloader.current_epoch == 1, "Current epoch should be 1" if batch_idx == 10: break - - dataloader.load_state_dict(dataloader.state_dict()) + state_dict = dataloader.state_dict() + print(state_dict) + dataloader.load_state_dict(state_dict) assert dataloader.restore # Verify remaining batches in the first epoch