diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 50ee71c1..4ad656db 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -555,7 +555,7 @@ def __init__( profile_dir: Optional[str] = None, prefetch_factor: Optional[int] = None, shuffle: Optional[bool] = None, - drop_last: Optional[bool] = False, + drop_last: Optional[bool] = None, collate_fn: Optional[Callable] = None, **kwargs: Any, ) -> None: # pyright: ignore diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 9486e9a9..3986a526 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -600,36 +600,46 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False) L = len(dataset) - assert len(dataset) == L + assert L == 20 for i in range(L): sequence = dataset[i] assert sequence[0].item() == i * block_size assert sequence[-1].item() == (i + 1) * block_size - 1 + monkeypatch.setenv("WORLD_SIZE", "2") + monkeypatch.setenv("GLOBAL_RANK", "0") + monkeypatch.setenv("NNODES", "1") dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False) + dataloader = StreamingDataLoader(dataset, batch_size=2, shuffle=False, num_workers=2) + assert dataset.drop_last # in distributed setting, this is forced automatically - dataset.distributed_env = _DistributedEnv(2, 0, 1) - dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2) - - assert len(dataloader) == 5 - - expected = [[0, 10], [20, 30], [40, 50], [60, 70], [80, 90]] + # L = 20, world size 2, num workers 2 + # L / (2 * 2) = 5 items per worker + # drop last -> 4 items per worker + # batch size = 2 -> 2 batches per worker -> len(dataloader) = 4 + assert len(dataloader) == 4 + expected = [[0, 10], [40, 50], [20, 30], [60, 70]] + returned = [] for batch_idx, batch in enumerate(dataloader): - x = [batch[0][0].item(), batch[1][0].item()] - # breakpoint() - # assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] + returned.append(batch[:, 0].tolist()) + assert returned == expected - dataset.distributed_env = _DistributedEnv(2, 1, 1) - dataloader = DataLoader(dataset, batch_size=2, shuffle=False) - - assert len(dataloader) == 5 - - expected = [[100, 110], [120, 130], [140, 150], [160, 170], [180, 190]] + monkeypatch.setenv("WORLD_SIZE", "2") + monkeypatch.setenv("GLOBAL_RANK", "1") + monkeypatch.setenv("NNODES", "1") + dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False) + dataloader = StreamingDataLoader(dataset, batch_size=2, shuffle=False, num_workers=2) + assert dataset.drop_last # in distributed setting, this is forced automatically + assert len(dataloader) == 4 + + expected = [[80, 90], [120, 130], [100, 110], [140, 150]] + returned = [] for batch_idx, batch in enumerate(dataloader): - assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] + returned.append(batch[:, 0].tolist()) + assert returned == expected @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")