diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py
index 50ee71c1..4ad656db 100644
--- a/src/litdata/streaming/dataloader.py
+++ b/src/litdata/streaming/dataloader.py
@@ -555,7 +555,7 @@ def __init__(
         profile_dir: Optional[str] = None,
         prefetch_factor: Optional[int] = None,
         shuffle: Optional[bool] = None,
-        drop_last: Optional[bool] = False,
+        drop_last: Optional[bool] = None,
         collate_fn: Optional[Callable] = None,
         **kwargs: Any,
     ) -> None:  # pyright: ignore
diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py
index 9486e9a9..3986a526 100644
--- a/tests/streaming/test_dataset.py
+++ b/tests/streaming/test_dataset.py
@@ -600,36 +600,46 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk
     dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
 
     L = len(dataset)
-    assert len(dataset) == L
+    assert L == 20
 
     for i in range(L):
         sequence = dataset[i]
         assert sequence[0].item() == i * block_size
         assert sequence[-1].item() == (i + 1) * block_size - 1
 
+    monkeypatch.setenv("WORLD_SIZE", "2")
+    monkeypatch.setenv("GLOBAL_RANK", "0")
+    monkeypatch.setenv("NNODES", "1")
     dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
+    dataloader = StreamingDataLoader(dataset, batch_size=2, shuffle=False, num_workers=2)
+    assert dataset.drop_last  # in distributed setting, this is forced automatically
 
-    dataset.distributed_env = _DistributedEnv(2, 0, 1)
-    dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2)
-
-    assert len(dataloader) == 5
-
-    expected = [[0, 10], [20, 30], [40, 50], [60, 70], [80, 90]]
+    # L = 20, world size 2, num workers 2
+    # L / (2 * 2) = 5 items per worker
+    # drop last -> 4 items per worker
+    # batch size = 2 -> 2 batches per worker -> len(dataloader) = 4
+    assert len(dataloader) == 4
 
+    expected = [[0, 10], [40, 50], [20, 30], [60, 70]]
+    returned = []
     for batch_idx, batch in enumerate(dataloader):
-        x = [batch[0][0].item(), batch[1][0].item()]
-        # breakpoint()
-        # assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
+        returned.append(batch[:, 0].tolist())
+    assert returned == expected
 
-    dataset.distributed_env = _DistributedEnv(2, 1, 1)
-    dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
-
-    assert len(dataloader) == 5
-
-    expected = [[100, 110], [120, 130], [140, 150], [160, 170], [180, 190]]
+    monkeypatch.setenv("WORLD_SIZE", "2")
+    monkeypatch.setenv("GLOBAL_RANK", "1")
+    monkeypatch.setenv("NNODES", "1")
+    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
+    dataloader = StreamingDataLoader(dataset, batch_size=2, shuffle=False, num_workers=2)
+    assert dataset.drop_last  # in distributed setting, this is forced automatically
 
+    assert len(dataloader) == 4
+    
+    expected = [[80, 90], [120, 130], [100, 110], [140, 150]]
+    returned = []
     for batch_idx, batch in enumerate(dataloader):
-        assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
+        returned.append(batch[:, 0].tolist())
+    assert returned == expected
 
 
 @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")