Commit c60d559
Fix StreamingDataset.get_len(num_workers=0) (#311)
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: tchaton <[email protected]>
3 people authored Aug 6, 2024
1 parent 777e8de commit c60d559
Showing 3 changed files with 24 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/litdata/streaming/dataset.py
@@ -197,13 +197,13 @@ def set_num_workers(self, num_workers: int) -> None:
         self.num_workers = num_workers or 1
 
     def get_len(self, num_workers: int, batch_size: int) -> int:
-        self.num_workers = num_workers
-        self.batch_size = batch_size
+        self.set_num_workers(num_workers)
+        self.set_batch_size(batch_size)
         worker_env = _WorkerEnv.detect()
         if self.shuffler is None:
             cache = self._create_cache(worker_env=worker_env)
             self.shuffler = self._create_shuffler(cache)
-        return self.shuffler.get_len(self.distributed_env, num_workers, batch_size, self.current_epoch)
+        return self.shuffler.get_len(self.distributed_env, self.num_workers, self.batch_size, self.current_epoch)
 
     def __iter__(self) -> "StreamingDataset":
         # When the StreamingDataset is used within map or optimize, let's refetch the distributed env.
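
For context (inferred from the diff above, not spelled out in the commit message): set_num_workers() coerces a falsy num_workers to 1, so routing both arguments through the setters keeps a DataLoader default of num_workers=0 from reaching the shuffler's length computation as a raw 0. A minimal sketch of the now-working path, mirroring the regression test added below; the /tmp path is illustrative:

    from litdata.streaming import Cache, StreamingDataLoader, StreamingDataset

    # Write 1000 items into a streaming-format directory.
    cache = Cache(input_dir="/tmp/demo", chunk_bytes="64MB")
    for i in range(1000):
        cache[i] = i
    cache.done()
    cache.merge()

    dataset = StreamingDataset("/tmp/demo", shuffle=True)
    # The dataloader is used with its default num_workers=0 here; before
    # this fix, get_len() received the raw 0 and the reported length broke.
    dataloader = StreamingDataLoader(dataset)
    assert len(dataloader) == 1000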
6 changes: 5 additions & 1 deletion tests/processing/test_functions.py
@@ -176,7 +176,7 @@ def test_optimize_append_overwrite(tmpdir):
     assert ds[:] == [(i, i**2, i**3) for i in range(0, 5)]
 
 
-@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow")
+@pytest.mark.skipif(sys.platform == "win32", reason="too slow")
 def test_optimize_checkpoint_in_none_and_append_mode(tmpdir):
     output_dir = str(tmpdir / "output_dir")
 
@@ -188,6 +188,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir):
         chunk_size=1,
         num_workers=2,
         use_checkpoint=True,
+        start_method="fork",
     )
 
     # check that the checkpoints are created
@@ -201,6 +202,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir):
         chunk_size=1,
         num_workers=2,
         use_checkpoint=True,
+        start_method="fork",
     )
 
     ds = StreamingDataset(output_dir)
@@ -221,6 +223,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir):
         num_workers=2,
         use_checkpoint=True,
         mode="append",
+        start_method="fork",
     )
 
     # check that the checkpoints are created
@@ -240,6 +243,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir):
         num_workers=2,
         use_checkpoint=True,
         mode="append",
+        start_method="fork",
     )
 
     ds = StreamingDataset(output_dir)
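
A note on the test change (an inference from the diff, not stated in the commit): the darwin skip is removed in the same change that passes start_method="fork", presumably because macOS defaults to the slower "spawn" start method and forking keeps these multi-worker optimize() runs fast enough to re-enable there. A hypothetical minimal call with the new argument; fn and inputs are placeholders:

    from litdata import optimize

    optimize(
        fn=lambda i: (i, i**2),  # placeholder transform
        inputs=list(range(4)),   # placeholder inputs
        output_dir="out",
        chunk_size=1,
        num_workers=2,
        start_method="fork",  # skip the platform default (spawn on macOS)
    )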
17 changes: 16 additions & 1 deletion tests/streaming/test_dataloader.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 from litdata.constants import _VIZ_TRACKER_AVAILABLE
-from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader
+from litdata.streaming import Cache, CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
 from litdata.streaming import dataloader as streaming_dataloader_module
 from torch import tensor
 
@@ -187,3 +187,18 @@ def test_custom_collate_multiworker():
 
     # Try calling the state_dict. No error should follow
     _state_dict = dataloader.state_dict()
+
+
+def test_dataloader_no_workers(tmpdir):
+    cache = Cache(input_dir=str(tmpdir), chunk_bytes="64MB")
+    for i in range(1000):
+        cache[i] = i
+
+    cache.done()
+    cache.merge()
+
+    dataset = StreamingDataset(str(tmpdir), shuffle=True)
+    dataloader = StreamingDataLoader(dataset)
+    assert len(dataset) == 1000
+    assert len(dataloader) == 1000
+    assert len(dataset) == 1000