From c60d55945a9ea98cf24f111a99d8eedecf91dbda Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 6 Aug 2024 20:33:04 +0200 Subject: [PATCH] Fix StreamingDataset.get_len(num_workers=0) (#311) Co-authored-by: thomas chaton Co-authored-by: tchaton --- src/litdata/streaming/dataset.py | 6 +++--- tests/processing/test_functions.py | 6 +++++- tests/streaming/test_dataloader.py | 17 ++++++++++++++++- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 66087056..2453ace2 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -197,13 +197,13 @@ def set_num_workers(self, num_workers: int) -> None: self.num_workers = num_workers or 1 def get_len(self, num_workers: int, batch_size: int) -> int: - self.num_workers = num_workers - self.batch_size = batch_size + self.set_num_workers(num_workers) + self.set_batch_size(batch_size) worker_env = _WorkerEnv.detect() if self.shuffler is None: cache = self._create_cache(worker_env=worker_env) self.shuffler = self._create_shuffler(cache) - return self.shuffler.get_len(self.distributed_env, num_workers, batch_size, self.current_epoch) + return self.shuffler.get_len(self.distributed_env, self.num_workers, self.batch_size, self.current_epoch) def __iter__(self) -> "StreamingDataset": # When the StreamingDataset is used within map or optimize, let's refetch the distributed env. diff --git a/tests/processing/test_functions.py b/tests/processing/test_functions.py index 2eb4ea7e..80eec0ba 100644 --- a/tests/processing/test_functions.py +++ b/tests/processing/test_functions.py @@ -176,7 +176,7 @@ def test_optimize_append_overwrite(tmpdir): assert ds[:] == [(i, i**2, i**3) for i in range(0, 5)] -@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow") +@pytest.mark.skipif(sys.platform == "win32", reason="too slow") def test_optimize_checkpoint_in_none_and_append_mode(tmpdir): output_dir = str(tmpdir / "output_dir") @@ -188,6 +188,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir): chunk_size=1, num_workers=2, use_checkpoint=True, + start_method="fork", ) # check that the checkpoints are created @@ -201,6 +202,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir): chunk_size=1, num_workers=2, use_checkpoint=True, + start_method="fork", ) ds = StreamingDataset(output_dir) @@ -221,6 +223,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir): num_workers=2, use_checkpoint=True, mode="append", + start_method="fork", ) # check that the checkpoints are created @@ -240,6 +243,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir): num_workers=2, use_checkpoint=True, mode="append", + start_method="fork", ) ds = StreamingDataset(output_dir) diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index f0a5e138..768cb4b5 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -3,7 +3,7 @@ import pytest import torch from litdata.constants import _VIZ_TRACKER_AVAILABLE -from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader +from litdata.streaming import Cache, CombinedStreamingDataset, StreamingDataLoader, StreamingDataset from litdata.streaming import dataloader as streaming_dataloader_module from torch import tensor @@ -187,3 +187,18 @@ def test_custom_collate_multiworker(): # Try calling the state_dict. No error should follow _state_dict = dataloader.state_dict() + + +def test_dataloader_no_workers(tmpdir): + cache = Cache(input_dir=str(tmpdir), chunk_bytes="64MB") + for i in range(1000): + cache[i] = i + + cache.done() + cache.merge() + + dataset = StreamingDataset(str(tmpdir), shuffle=True) + dataloader = StreamingDataLoader(dataset) + assert len(dataset) == 1000 + assert len(dataloader) == 1000 + assert len(dataset) == 1000