Commit

Bugfix: inconsistent streaming dataloader state (specific to StreamingDataset) (#318)

bhimrazy authored Aug 14, 2024
1 parent 9791488 commit 4dfd98c
Showing 5 changed files with 160 additions and 3 deletions.
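The gist of the fix: `StreamingDataLoader.load_state_dict` used to set its `restore` flag unconditionally, so a state captured after a fully consumed epoch left the loader inconsistent on the next iteration. The flag is now set only while unseen samples remain, and the dataset's saved state is cleared whenever an epoch completes. A minimal sketch of the checkpoint/resume workflow affected (hedged: the dataset path is hypothetical, and the imports are assumed to be litdata's top-level exports, mirroring the tests added below):

    from litdata import StreamingDataLoader, StreamingDataset

    # Assumes an optimized dataset was previously written to this
    # (hypothetical) directory.
    dataset = StreamingDataset("data/my_optimized_dataset", shuffle=True)
    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)

    # Interrupt training mid-epoch and capture the loader state.
    for batch_idx, batch in enumerate(dataloader):
        if batch_idx == 10:
            break
    state = dataloader.state_dict()

    # Reloading resumes the interrupted epoch: `restore` is True only while
    # samples remain, and is toggled back to False once `__iter__` completes.
    dataloader.load_state_dict(state)
    assert dataloader.restore
    for batch in dataloader:
        ...  # yields the remaining batches of the interrupted epoch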
5 changes: 5 additions & 0 deletions src/litdata/streaming/combined.py
@@ -134,6 +134,11 @@ def set_drop_last(self, drop_last: bool) -> None:
        for dataset in self._datasets:
            dataset.set_drop_last(drop_last)

    def reset_state_dict(self) -> None:
        """Reset the state of the dataset."""
        for dataset in self._datasets:
            dataset.reset_state_dict()

    def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
        if any(not isinstance(d, StreamingDataset) for d in datasets):
            raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")
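This method exists so the dataloader can call `reset_state_dict()` uniformly on whichever dataset it wraps (see the `__iter__` change in dataloader.py below); the combined wrapper simply forwards the reset to every child dataset. A two-line illustration (hedged; `dataset_a` and `dataset_b` stand for any two StreamingDataset instances):

    # Resetting the wrapper clears the pending `_state_dict` on each child.
    combined = CombinedStreamingDataset([dataset_a, dataset_b])
    combined.reset_state_dict()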
13 changes: 12 additions & 1 deletion src/litdata/streaming/dataloader.py
@@ -615,6 +615,7 @@ def __iter__(self) -> Any:
            self.current_epoch += 1
            self._num_samples_yielded_combined = {}
            self._num_samples_yielded_streaming = 0
            self.dataset.reset_state_dict()

        self.dataset.set_epoch(self.current_epoch)
@@ -700,13 +701,23 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:

        # Inform that we are resuming and disable resetting the StreamingDataLoader state.
        # This is toggled back to False when the `__iter__` method of the StreamingDataLoader completes.
        self.restore = True
        # self.restore = True

        if isinstance(self.dataset, CombinedStreamingDataset):
            self.dataset._set_use_streaming_dataloader(True)
            self.dataset.load_state_dict(obj)

            # Inform that the dataloader is resuming.
            # TODO: Check if the number of samples yielded is less than the length of the dataset.
            # Also, `len` is not available for CombinedStreamingDataset when weights are provided.
            self.restore = True

        elif isinstance(self.dataset, StreamingDataset):
            self.dataset.load_state_dict(obj["dataset"])

            # Inform that the dataloader is resuming.
            if self._num_samples_yielded_streaming < len(self.dataset):
                self.restore = True
        else:
            raise RuntimeError("The provided dataset should be a `StreamingDataset` or a `CombinedStreamingDataset`.")

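The behavioral change above: for a plain StreamingDataset, `restore` is now set only when fewer samples have been yielded than the dataset holds, so a state saved after a fully consumed epoch resumes cleanly into the next one. (For CombinedStreamingDataset the flag is still set unconditionally, since its length is undefined when sampling weights are provided; see the TODO above.) A sketch of the fixed behavior, mirroring the new tests below:

    # Run an epoch to exhaustion, then checkpoint.
    for _ in dataloader:
        pass
    state = dataloader.state_dict()

    # Before this fix, `restore` stayed True here and the next iteration
    # replayed an inconsistent state; now there is nothing to restore.
    dataloader.load_state_dict(state)
    assert not dataloader.restore
    for batch in dataloader:
        ...  # starts the next epoch from scratch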
14 changes: 12 additions & 2 deletions src/litdata/streaming/dataset.py
@@ -342,12 +342,14 @@ def __next__(self) -> Any:
        # Prevent creating more batches on a given process
        if self.global_index >= self.stop_length:
            self.current_epoch += 1
            self.reset_state_dict()
            raise StopIteration

        # Lazily re-populate the interval to reduce memory usage.
        if len(self.current_indexes) == 0:
            if self.chunk_index == self.num_chunks:
                self.current_epoch += 1
                self.reset_state_dict()
                raise StopIteration

            # reset index
@@ -392,7 +394,7 @@ def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int

        return {
            "num_samples_yielded": num_samples_yielded,
            "num_workers": num_workers,
            "num_workers": num_workers or 1,
            "batch_size": batch_size,
            "current_epoch": self.current_epoch,
            "input_dir_path": self.input_dir.path,
@@ -411,13 +413,15 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            # the state is restored within the workers
            self._state_dict = state_dict

    def reset_state_dict(self) -> None:
        self._state_dict = None

    def _validate_state_dict(self) -> None:
        assert self._state_dict
        assert self.worker_env
        assert self.cache

        state: Dict[str, Any] = self._state_dict

        if state["shuffle"] != self.shuffle:
            raise ValueError(
                "The provided `shuffle` state doesn't match the current one. "
@@ -471,6 +475,12 @@ def _validate_state_dict(self) -> None:
                f"Found `{self.drop_last}` instead of `{state['drop_last']}`."
            )

        if state["num_samples_yielded"] > len(self):
            raise ValueError(
                "The provided `num_samples_yielded` state is greater than the dataset length. "
                f"Found `{state['num_samples_yielded']}` instead of `{len(self)}`."
            )

    def reset(self) -> None:
        # undo all the properties associated with original dataset
        default_properties: Dict[str, Any] = {
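The new guard in `_validate_state_dict` rejects checkpoints that claim more progress than the dataset holds, e.g. a state saved against a different, larger dataset. A hypothetical illustration (assuming a 100-sample dataset; note the check runs when the state is restored within the workers, not at `load_state_dict` time):

    # `num_samples_yielded=150` exceeds len(dataset) == 100.
    state = dataset.state_dict(num_samples_yielded=150, num_workers=2, batch_size=4)
    dataset.load_state_dict(state)
    # Once the workers validate the restored state, this raises:
    # ValueError: The provided `num_samples_yielded` state is greater than the
    # dataset length. Found `150` instead of `100`.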
3 changes: 3 additions & 0 deletions tests/streaming/test_combined.py
@@ -17,6 +17,9 @@ class TestCombinedStreamingDataset(CombinedStreamingDataset):
    def _check_datasets(self, datasets) -> None:
        pass

    def reset_state_dict(self):
        pass


def test_combined_dataset_num_samples_yield():
    dataset = TestCombinedStreamingDataset(
128 changes: 128 additions & 0 deletions tests/streaming/test_dataloader.py
@@ -56,6 +56,9 @@ class TestCombinedStreamingDataset(CombinedStreamingDataset):
    def _check_datasets(self, datasets) -> None:
        pass

    def reset_state_dict(self):
        pass


def test_streaming_dataloader():
    dataset = TestCombinedStreamingDataset(
@@ -202,3 +205,128 @@ def test_dataloader_no_workers(tmpdir):
    assert len(dataset) == 1000
    assert len(dataloader) == 1000
    assert len(dataset) == 1000


@pytest.mark.timeout(120)
def test_dataloader_with_loading_states(tmpdir):
    cache = Cache(input_dir=str(tmpdir), chunk_bytes="64MB")
    for i in range(100):
        cache[i] = i
    cache.done()
    cache.merge()

    dataset = StreamingDataset(str(tmpdir), shuffle=True)

    # Test dataloader without explicit num workers
    dataloader = StreamingDataLoader(dataset, batch_size=4)
    dataloader.load_state_dict(dataloader.state_dict())
    batch = next(iter(dataloader))
    assert len(batch) == 4, "Batch size should be 4"
    assert len(dataloader) == 25, "Dataloader length should be 25 (100 items / batch size 4)"

    # Test dataloader with num workers
    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
    assert len(dataloader) == 25, "Dataloader length should be 25 (100 items / batch size 4)"

    # Verify dataloader state after partial iteration
    for batch_idx, batch in enumerate(dataloader):
        assert dataloader.current_epoch == 1, "Current epoch should be 1"
        if batch_idx == 10:
            break
    dataloader.load_state_dict(dataloader.state_dict())
    assert dataloader.restore
    # Verify remaining batches in the first epoch
    count = 0
    for _ in dataloader:
        assert dataloader.current_epoch == 1, "Current epoch should be 1"
        count += 1
    assert count == 15, "There should be 15 batches remaining in the first epoch"
    assert not dataloader.restore

    # Verify batches in the second epoch
    count = 0
    for _ in dataloader:
        assert dataloader.current_epoch == 2, "Current epoch should be 2"
        count += 1
    assert count >= 25, "There should be at least 25 batches in the second epoch"

    # Verify that the dataloader can resume after the last epoch completes
    dataloader.load_state_dict(dataloader.state_dict())
    assert not dataloader.restore
    count = 0
    for _ in dataloader:
        assert dataloader.current_epoch == 3, "Current epoch should be 3"
        count += 1
    assert count >= 25, "There should be at least 25 batches in the third epoch"


@pytest.mark.timeout(120)
def test_dataloader_states_with_persistent_workers(tmpdir):
    cache = Cache(input_dir=str(tmpdir), chunk_bytes="64MB")
    for i in range(100):
        cache[i] = i
    cache.done()
    cache.merge()

    dataset = StreamingDataset(str(tmpdir), shuffle=True)

    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
    assert len(dataloader) == 25, "Dataloader length should be 25 (100 items / batch size 4)"

    # Verify dataloader state after partial iteration
    for batch_idx, batch in enumerate(dataloader):
        assert dataloader.current_epoch == 1, "Current epoch should be 1"
        if batch_idx == 10:
            break

    prev_dataloader_state = dataloader.state_dict()
    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2, persistent_workers=True)
    dataloader.load_state_dict(prev_dataloader_state)
    assert dataloader.restore

    # Verify remaining batches in the first epoch
    count = 0
    for _ in dataloader:
        assert dataloader.current_epoch == 1, "Current epoch should be 1"
        count += 1
    assert count == 15, "There should be 15 batches remaining in the first epoch"
    assert not dataloader.restore

    # Verify batches in the second epoch
    count = 0
    for _ in dataloader:
        assert dataloader.current_epoch == 2, "Current epoch should be 2"
        count += 1
    assert count >= 25, "There should be at least 25 batches in the second epoch"

    # Verify that the dataloader can resume after the last epoch completes
    dataloader.load_state_dict(dataloader.state_dict())
    assert not dataloader.restore
    count = 0
    for _ in dataloader:
        assert dataloader.current_epoch == 3, "Current epoch should be 3"
        count += 1
    assert count >= 25, "There should be at least 25 batches in the third epoch"


@pytest.mark.timeout(60)
def test_resume_dataloader_with_new_dataset(tmpdir):
    dataset_1_path = tmpdir.join("dataset_1")
    dataset_2_path = tmpdir.join("dataset_2")
    for dataset in [dataset_1_path, dataset_2_path]:
        cache = Cache(input_dir=str(dataset), chunk_bytes="64MB")
        for i in range(50):
            cache[i] = i
        cache.done()
        cache.merge()
    dataset = StreamingDataset(str(dataset_1_path), shuffle=True)
    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
    for _ in dataloader:
        assert dataloader.current_epoch == 1, "Current epoch should be 1"

    dataloader_state = dataloader.state_dict()
    dataset = StreamingDataset(str(dataset_2_path), shuffle=True)
    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
    dataloader.load_state_dict(dataloader_state)
    for _ in dataloader:
        assert dataloader.current_epoch == 2, "Current epoch should be 2"
