From 499bb9dcf473e01ac6c0748fec5539527e1be127 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 10 Jul 2024 22:57:52 +0000 Subject: [PATCH] debug --- tests/streaming/test_dataset.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 4ec132d4..29b5a058 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -809,11 +809,13 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir): @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) -def test_dataset_resume_on_future_chunks(tmpdir): +def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch): """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have the same size.""" s3_cache_dir = str(tmpdir / "s3cache") + optimize_cache_dir = str(tmpdir / "optimize_cache") data_dir = str(tmpdir / "optimized") + monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", optimize_cache_dir) optimize( fn=_simple_preprocess, @@ -825,8 +827,6 @@ def test_dataset_resume_on_future_chunks(tmpdir): assert len(os.listdir(tmpdir / "optimized")) > 1 os.mkdir(s3_cache_dir) - shutil.rmtree("/cache/chunks", ignore_errors=True) - train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir) batches_to_fetch = 16 batch_to_resume_from = None @@ -839,7 +839,6 @@ def test_dataset_resume_on_future_chunks(tmpdir): shutil.rmtree(s3_cache_dir) os.mkdir(s3_cache_dir) - shutil.rmtree("/cache/chunks", ignore_errors=True) train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir) train_dataloader.load_state_dict(dataloader_state) # The next batch after resuming must match what we should have gotten next in the initial loop