diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py
index a725cfa6..7b46bed8 100644
--- a/src/litdata/streaming/shuffle.py
+++ b/src/litdata/streaming/shuffle.py
@@ -20,7 +20,6 @@
 from litdata.streaming import Cache
 from litdata.utilities.env import _DistributedEnv
 from litdata.utilities.shuffle import (
-    _associate_chunks_and_internals_to_ranks,
     _associate_chunks_and_internals_to_workers,
     _intra_node_chunk_shuffle,
 )
diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py
index 981576ba..cfba9a5e 100644
--- a/tests/streaming/test_dataset.py
+++ b/tests/streaming/test_dataset.py
@@ -811,10 +811,13 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False):
 @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
 @mock.patch.dict(os.environ, {}, clear=True)
 @pytest.mark.timeout(60)
-@pytest.mark.parametrize("shuffle", [
-    # True,
-    False,
-])
+@pytest.mark.parametrize(
+    "shuffle",
+    [
+        # True,
+        False,
+    ],
+)
 def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
     """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have
     the same size."""
@@ -840,9 +843,9 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
     # 8 * 100 tokens = 800 tokens
     # 800 / 10 = 80 blocks
     # batch size 2: 80 / 2 = 40 batches
-    # assert len(train_dataloader.dataset) == 80 
-    # assert len(train_dataloader) == 40 
-    
+    # assert len(train_dataloader.dataset) == 80
+    # assert len(train_dataloader) == 40
+
     for i, batch in enumerate(train_dataloader):
         if i == batches_to_fetch:
            dataloader_state = train_dataloader.state_dict()
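
For context (not part of the patch): a minimal sketch of the block/batch arithmetic that the commented-out assertions in `test_dataset_resume_on_future_chunks` refer to, assuming the chunk count, block size, and batch size stated in the test's comments.

```python
# Hypothetical illustration of the counts in the test comments:
# 8 chunks * 100 tokens = 800 tokens, block size 10 -> 80 blocks, batch size 2 -> 40 batches.
num_chunks = 8
tokens_per_chunk = 100
block_size = 10
batch_size = 2

total_tokens = num_chunks * tokens_per_chunk   # 800
num_blocks = total_tokens // block_size        # 80  -> expected len(train_dataloader.dataset)
num_batches = num_blocks // batch_size         # 40  -> expected len(train_dataloader)

assert (total_tokens, num_blocks, num_batches) == (800, 80, 40)
```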