From f52a501440091e826feef03b983a426074d8f72a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Jul 2024 14:15:05 +0000 Subject: [PATCH] fix epoch reshuffling test --- tests/streaming/test_dataset.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index a32d6601..0183f3ae 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -684,7 +684,7 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache: @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") -def test_resumable_dataset_two_workers(tmpdir): +def test_dataset_reshuffling_every_epoch(tmpdir): seed_everything(42) data_dir = os.path.join(tmpdir, "data") @@ -786,7 +786,6 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir): input_dir=Dir(cache_dir, data_dir), item_loader=TokensLoader(block_size), shuffle=True ) - dataset.current_epoch = 1 dataloader = StreamingDataLoader(dataset, num_workers=2, batch_size=2, prefetch_factor=1, persistent_workers=True) batches_epoch_1 = [] @@ -800,9 +799,7 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir): batches_epoch_2.append(batch) assert len(os.listdir(cache_dir)) == 51 - - for batch_1, batch_2 in zip(batches_epoch_1, batches_epoch_2): - assert not torch.equal(batch_1, batch_2) + assert not all(torch.equal(b1, b2) for b1, b2 in zip(batches_epoch_1, batches_epoch_2)) def _simple_preprocess(_):