Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix: inconsistent streaming dataloader state (specific to StreamingDataset) #318

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
3658d64
chore: Add reset_state_dict method to StreamingDataset
bhimrazy Aug 9, 2024
8eb1c7f
chore: Update num_workers fallback value in StreamingDataset
bhimrazy Aug 9, 2024
10c10b3
fix: Reset dataset state after each epoch
bhimrazy Aug 9, 2024
391c68b
update
tchaton Aug 9, 2024
5d74ed8
Update src/litdata/streaming/dataset.py
tchaton Aug 9, 2024
7412064
feat: Add test for dataloader with loading states
bhimrazy Aug 9, 2024
0290a30
chore: Add test for dataloader with loading states with peristent wor…
bhimrazy Aug 9, 2024
00c2928
rm commment
bhimrazy Aug 9, 2024
25a87b7
🐛 fix: restore only if there are any remaining batches/samples to str…
bhimrazy Aug 11, 2024
678c3fc
added notes to checkout later
bhimrazy Aug 11, 2024
532dacd
Merge branch 'main' into bugfix/316-streaming-dataloader-state
bhimrazy Aug 11, 2024
9866992
add note
bhimrazy Aug 11, 2024
16bc40f
chore: Add test for dataloader resuming after completing last epoch
bhimrazy Aug 11, 2024
d3f9498
feat: Add test for resuming dataloader with new dataset
bhimrazy Aug 11, 2024
6769694
adds type ignore
bhimrazy Aug 11, 2024
81bc537
update timeout and num of samples
bhimrazy Aug 11, 2024
998fe5a
Add explicit test for resuming dataloader with new dataset
bhimrazy Aug 11, 2024
61120a4
chore: add validation for num_samples_yielded
bhimrazy Aug 11, 2024
faa0213
Merge branch 'main' into bugfix/316-streaming-dataloader-state
bhimrazy Aug 12, 2024
d98681c
removed unrequired test, as it was testing for wrong thing, when rese…
bhimrazy Aug 12, 2024
743f0dd
removed the unnecesssary todo
bhimrazy Aug 12, 2024
2db07e0
chore: Add restore flag to dataloader tests
bhimrazy Aug 12, 2024
fc3a960
chore: Add restore flag to dataloader for StreamingDataset
bhimrazy Aug 13, 2024
4a50cac
update
bhimrazy Aug 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/litdata/streaming/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ def set_drop_last(self, drop_last: bool) -> None:
for dataset in self._datasets:
dataset.set_drop_last(drop_last)

def reset_state_dict(self) -> None:
"""Reset the state of the dataset."""
for dataset in self._datasets:
dataset.reset_state_dict()

def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
if any(not isinstance(d, StreamingDataset) for d in datasets):
raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")
Expand Down
1 change: 1 addition & 0 deletions src/litdata/streaming/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,7 @@ def __iter__(self) -> Any:
self.current_epoch += 1
self._num_samples_yielded_combined = {}
self._num_samples_yielded_streaming = 0
self.dataset.reset_state_dict()
bhimrazy marked this conversation as resolved.
Show resolved Hide resolved

self.dataset.set_epoch(self.current_epoch)

Expand Down
7 changes: 6 additions & 1 deletion src/litdata/streaming/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,14 @@ def __next__(self) -> Any:
# Prevent to create more batch on a given process
if self.global_index >= self.stop_length:
self.current_epoch += 1
self.reset_state_dict()
raise StopIteration

# Lazily re-populate the interval to reduce memory usage.
if len(self.current_indexes) == 0:
if self.chunk_index == self.num_chunks:
self.current_epoch += 1
self.reset_state_dict()
raise StopIteration

# reset index
Expand Down Expand Up @@ -388,7 +390,7 @@ def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int

return {
"num_samples_yielded": num_samples_yielded,
"num_workers": num_workers,
"num_workers": num_workers or 1,
"batch_size": batch_size,
"current_epoch": self.current_epoch,
"input_dir_path": self.input_dir.path,
Expand All @@ -407,6 +409,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# the state is restored within the workers
self._state_dict = state_dict

def reset_state_dict(self) -> None:
self._state_dict = None

def _validate_state_dict(self) -> None:
assert self._state_dict
assert self.worker_env
Expand Down
3 changes: 3 additions & 0 deletions tests/streaming/test_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class TestCombinedStreamingDataset(CombinedStreamingDataset):
def _check_datasets(self, datasets) -> None:
pass

def reset_state_dict(self):
pass


def test_combined_dataset_num_samples_yield():
dataset = TestCombinedStreamingDataset(
Expand Down
46 changes: 46 additions & 0 deletions tests/streaming/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class TestCombinedStreamingDataset(CombinedStreamingDataset):
def _check_datasets(self, datasets) -> None:
pass

def reset_state_dict(self):
pass


def test_streaming_dataloader():
dataset = TestCombinedStreamingDataset(
Expand Down Expand Up @@ -202,3 +205,46 @@ def test_dataloader_no_workers(tmpdir):
assert len(dataset) == 1000
assert len(dataloader) == 1000
assert len(dataset) == 1000


@pytest.mark.timeout(120)
def test_dataloader_with_loading_states(tmpdir):
bhimrazy marked this conversation as resolved.
Show resolved Hide resolved
cache = Cache(input_dir=str(tmpdir), chunk_bytes="64MB")
for i in range(100):
cache[i] = i
cache.done()
cache.merge()

dataset = StreamingDataset(str(tmpdir), shuffle=True)

# Test dataloader without explicit num workers
dataloader = StreamingDataLoader(dataset, batch_size=4)
dataloader.load_state_dict(dataloader.state_dict())
batch = next(iter(dataloader))
assert len(batch) == 4, "Batch size should be 4"
assert len(dataloader) == 25, "Dataloader length should be 25 (100 items / batch size 4)"

# Test dataloader with num workers
dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
assert len(dataloader) == 25, "Dataloader length should be 25 (100 items / batch size 4)"

# Verify dataloader state after partial iteration
for batch_idx, batch in enumerate(dataloader):
assert dataloader.current_epoch == 1, "Current epoch should be 1"
if batch_idx == 10:
break
dataloader.load_state_dict(dataloader.state_dict())

# Verify remaining batches in the first epoch
count = 0
for _ in dataloader:
assert dataloader.current_epoch == 1, "Current epoch should be 1"
count += 1
assert count == 15, "There should be atleast 15 batches remaining in the first epoch"

# Verify batches in the second epoch
count = 0
for _ in dataloader:
assert dataloader.current_epoch == 2, "Current epoch should be 2"
count += 1
assert count >= 25, "There should be at least 25 batches in the second epoch"
Loading