diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index f868a453..aafb509f 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -799,16 +799,22 @@ def _simple_preprocess(_): yield torch.randint(0, 100, size=(10,), dtype=torch.int64) -def _get_simulated_s3_dataloader(tmpdir): +def _get_simulated_s3_dataloader(cache_dir, data_dir): dataset = EmulateS3StreamingDataset( - input_dir=Dir(str(tmpdir / "s3cache"), str(tmpdir / "optimized")), + input_dir=Dir(cache_dir, data_dir), item_loader=TokensLoader(block_size=10), ) return StreamingDataLoader(dataset, batch_size=2, num_workers=1) +@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) def test_dataset_resume_on_future_chunks(tmpdir): + """This test is constructed to test resuming from a chunk past the first chunk, when + subsequent chunks don't have the same size.""" + s3_cache_dir = str(tmpdir / "s3cache") + data_dir = str(tmpdir / "optimized") + optimize( fn=_simple_preprocess, inputs=list(range(8)), @@ -816,12 +822,12 @@ def test_dataset_resume_on_future_chunks(tmpdir): chunk_size=190, num_workers=4, ) - assert len(os.listdir(tmpdir / "optimized")) == 9 # 8 chunks + 1 index file + assert len(os.listdir(tmpdir / "optimized")) > 1 - os.mkdir(tmpdir / "s3cache") - shutil.rmtree("/cache/chunks", ignore_errors=True) # TODO + os.mkdir(s3_cache_dir) + shutil.rmtree("/cache/chunks", ignore_errors=True) - train_dataloader = _get_simulated_s3_dataloader(tmpdir) + train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir) batches_to_fetch = 16 batch_to_resume_from = None for i, batch in enumerate(train_dataloader): @@ -831,11 +837,12 @@ def test_dataset_resume_on_future_chunks(tmpdir): batch_to_resume_from = batch break - shutil.rmtree(tmpdir / "s3cache") - os.mkdir(tmpdir / "s3cache") + shutil.rmtree(s3_cache_dir) + os.mkdir(s3_cache_dir) shutil.rmtree("/cache/chunks", ignore_errors=True) - train_dataloader = _get_simulated_s3_dataloader(tmpdir) + train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir) train_dataloader.load_state_dict(dataloader_state) + # The next batch after resuming must match what we should have gotten next in the initial loop assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from)