Skip to content

Commit

Permalink
Fix: Resolve drop_last not passed down from the StreamingDataLoader t…
Browse files Browse the repository at this point in the history
…o the datasets (#147)
  • Loading branch information
tchaton authored Jun 1, 2024
1 parent bb362a0 commit d2802bd
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 2 deletions.
5 changes: 5 additions & 0 deletions src/litdata/streaming/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ def set_shuffle(self, shuffle: bool) -> None:
for dataset in self._datasets:
dataset.set_shuffle(shuffle)

def set_drop_last(self, drop_last: bool) -> None:
    """Forward the ``drop_last`` flag to every wrapped dataset.

    Args:
        drop_last: Value handed unchanged to each dataset's ``set_drop_last``.
    """
    for wrapped in self._datasets:
        wrapped.set_drop_last(drop_last)

def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
if any(not isinstance(d, StreamingDataset) for d in datasets):
raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")
Expand Down
4 changes: 4 additions & 0 deletions src/litdata/streaming/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def __init__(
profile_dir: Optional[str] = None,
prefetch_factor: Optional[int] = None,
shuffle: Optional[bool] = None,
drop_last: Optional[bool] = False,
**kwargs: Any,
) -> None: # pyright: ignore
if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)):
Expand All @@ -551,6 +552,9 @@ def __init__(
if shuffle is not None:
dataset.set_shuffle(shuffle)

if drop_last is not None:
dataset.set_drop_last(drop_last)

shuffle = None

if profile_batches and not _VIZ_TRACKER_AVAILABLE:
Expand Down
3 changes: 3 additions & 0 deletions src/litdata/streaming/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ def __init__(
def set_shuffle(self, shuffle: bool) -> None:
    """Store the ``shuffle`` flag on this dataset.

    Args:
        shuffle: Value assigned to ``self.shuffle``.
    """
    self.shuffle = shuffle

def set_drop_last(self, drop_last: bool) -> None:
    """Store the ``drop_last`` flag on this dataset.

    Args:
        drop_last: Value assigned to ``self.drop_last``.
    """
    self.drop_last = drop_last

def set_epoch(self, current_epoch: int) -> None:
"""Set the current epoch to the dataset on epoch starts.
Expand Down
1 change: 0 additions & 1 deletion status.json

This file was deleted.

21 changes: 20 additions & 1 deletion tests/streaming/test_combined.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import sys
from unittest.mock import ANY
from unittest.mock import ANY, MagicMock

import pytest
import torch
Expand Down Expand Up @@ -53,6 +53,19 @@ def test_combined_dataset_num_samples_yield_iterate_over_all():
assert len(samples) == 20


def test_drop_last_and_shuffle():
    """Check that the loader forwards ``shuffle`` and ``drop_last`` to every wrapped dataset."""
    dataset_mock_1 = MagicMock()
    dataset_mock_2 = MagicMock()

    dataset = TestCombinedStreamingDataset([dataset_mock_1, dataset_mock_2], 42, iterate_over_all=True)
    StreamingDataLoader(dataset, shuffle=True, drop_last=True)

    # Assert on the forwarded value rather than on "some call happened":
    # a bare assert_called() would still pass if the wrong flag were passed down.
    for dataset_mock in (dataset_mock_1, dataset_mock_2):
        dataset_mock.set_shuffle.assert_called_with(True)
        dataset_mock.set_drop_last.assert_called_with(True)


class TestStatefulDataset:
def __init__(self, size, step):
self.size = size
Expand Down Expand Up @@ -177,6 +190,12 @@ def state_dict(self, **kwargs):
def set_epoch(self, current_epoch):
    """Accept the epoch number and do nothing with it (test stub)."""

def set_shuffle(self, _):
    """Accept the shuffle flag and do nothing with it (test stub)."""

def set_drop_last(self, _):
    """Accept the drop_last flag and do nothing with it (test stub)."""


def test_combined_dataset():
dataset1 = SimpleDataset(0, 10)
Expand Down
4 changes: 4 additions & 0 deletions tests/streaming/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, size, step):
self.step = step
self.counter = 0
self.shuffle = None
self.drop_last = None

def set_shuffle(self, shuffle):
    """Record the ``shuffle`` value on the stub so a test can inspect it."""
    self.shuffle = shuffle
Expand Down Expand Up @@ -41,6 +42,9 @@ def load_state_dict(self, state_dict):
def set_epoch(self, current_epoch):
    """Accept the epoch number and do nothing with it (test stub)."""

def set_drop_last(self, drop_last):
    """Record the ``drop_last`` value on the stub so a test can inspect it."""
    self.drop_last = drop_last


class TestCombinedStreamingDataset(CombinedStreamingDataset):
def _check_datasets(self, datasets) -> None:
Expand Down

0 comments on commit d2802bd

Please sign in to comment.