Init implementation of per_stream batching
schopra8 committed Dec 20, 2024
1 parent bd116a4 commit efa9858
Showing 2 changed files with 30 additions and 3 deletions.
25 changes: 23 additions & 2 deletions src/litdata/streaming/combined.py
@@ -13,7 +13,7 @@

 import random
 from copy import deepcopy
-from typing import Any, Dict, Iterator, List, Optional, Sequence
+from typing import Any, Dict, Iterator, List, Literal, Optional, Sequence

 from torch.utils.data import IterableDataset

@@ -42,6 +42,7 @@ def __init__(
         seed: int = 42,
         weights: Optional[Sequence[float]] = None,
         iterate_over_all: bool = True,
+        batching_method: Literal["stratified", "per_stream"] = "stratified",
     ) -> None:
         """Enable streaming data from multiple StreamingDataset instances with the sampling ratio of your choice.
@@ -51,6 +52,9 @@ def __init__(
             weights: The sampling ratio for the datasets
             iterate_over_all: When iterate_over_all is True, the combined dataset iterates over all the datasets.
                 Otherwise, it stops as soon as one raises a StopIteration.
+            batching_method: When batching_method is "stratified" (the default), every sample in a batch is drawn
+                randomly across all datasets. When batching_method is "per_stream", every sample in a batch is
+                drawn from the same dataset, and a new dataset is selected at random after each batch.
         """
         self._check_datasets(datasets)
@@ -67,6 +71,7 @@ def __init__(
         )

         self._iterate_over_all = iterate_over_all
+        self._batching_method = batching_method

         if weights is None:
             # Weighted based on the dataset length
@@ -165,6 +170,7 @@ def __iter__(self) -> Iterator[Any]:
             self._use_streaming_dataloader,
             num_samples_yielded,
             self._iterate_over_all,
+            self._batching_method,
         )
         return self._iterator

@@ -204,6 +210,7 @@ def __init__(
         use_streaming_dataloader: bool,
         num_samples_yielded: Any,
         iterate_over_all: bool = False,
+        batching_method: Literal["stratified", "per_stream"] = "stratified",
     ) -> None:
         self._datasets = datasets
         self._dataset_iters = [iter(dataset) for dataset in datasets]
@@ -213,6 +220,8 @@ def __init__(
         self._weights = deepcopy(weights)
         self._rng = random.Random(seed)
         self._iterate_over_all = iterate_over_all
+        self._batching_method = batching_method
+        self._cur_dataset_index = -1
         self._is_done = False

         if num_samples_yielded is not None:
@@ -234,6 +243,7 @@ def __next__(self) -> Any:
                     dataset_index = self._get_dataset_index()
                 elif len(indexes_left) == 1:
                     dataset_index = indexes_left[0]
+                self._cur_dataset_index = dataset_index
                 return self._get_sample(dataset_index)
             except StopIteration as e:
                 if len(indexes_left) == 1:
@@ -250,11 +260,22 @@ def __next__(self) -> Any:
         return self._get_sample(self._get_dataset_index())

     def _get_dataset_index(self) -> int:
+        if self._batching_method == "stratified":
+            # randomly select a new dataset index for every sample
+            self._set_new_dataset_index()
+        elif self._batching_method == "per_stream":
+            # randomly select a dataset index only if none is selected yet
+            if self._cur_dataset_index == -1:
+                self._set_new_dataset_index()
+        return self._cur_dataset_index
+
+    def _set_new_dataset_index(self) -> None:
         # randomly select a dataset index
         indexes = [index for index in self._dataset_indexes if index is not None]
         weights = [w for w in self._weights if w is not None]
         (dataset_index,) = self._rng.choices(indexes, weights=weights, k=1)
-        return dataset_index
+        self._cur_dataset_index = dataset_index

     def _get_sample(self, dataset_index: int) -> Any:
         # get the sample
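For context, here is a minimal usage sketch of the new option (the bucket paths are hypothetical, and in litdata explicit weights are passed together with iterate_over_all=False):

```python
from litdata import CombinedStreamingDataset, StreamingDataset

# Hypothetical dataset locations -- any StreamingDataset inputs work here.
code_data = StreamingDataset(input_dir="s3://my-bucket/code")
text_data = StreamingDataset(input_dir="s3://my-bucket/text")

# With batching_method="per_stream", every sample in a batch comes from a
# single dataset, and a new dataset is drawn at random on batch boundaries.
combined = CombinedStreamingDataset(
    datasets=[code_data, text_data],
    weights=[0.75, 0.25],
    iterate_over_all=False,  # explicit weights are used with this setting
    batching_method="per_stream",
)
```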
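To make the difference between the two methods concrete, here is a small self-contained simulation of the selection policies described in the docstring above (toy code, not the library implementation):

```python
import random

rng = random.Random(42)
weights = [0.75, 0.25]  # sampling ratio across two toy datasets

def stratified_batch(batch_size: int) -> list:
    # every sample independently re-draws a dataset index
    return rng.choices([0, 1], weights=weights, k=batch_size)

def per_stream_batch(batch_size: int) -> list:
    # one dataset index is drawn per batch and reused for every sample in it
    (index,) = rng.choices([0, 1], weights=weights, k=1)
    return [index] * batch_size

print(stratified_batch(4))  # indices may be mixed within a batch, e.g. [0, 1, 0, 0]
print(per_stream_batch(4))  # one index repeated, e.g. [0, 0, 0, 0]
```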
8 changes: 7 additions & 1 deletion src/litdata/streaming/dataloader.py
@@ -18,7 +18,7 @@

 from copy import deepcopy
 from importlib import reload
 from itertools import cycle
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Union

 import torch
 from torch.utils.data import Dataset, IterableDataset
@@ -549,6 +549,7 @@ def __init__(
         self,
         dataset: Union[StreamingDataset, CombinedStreamingDataset],
         *args: Any,
+        batching_method: Literal["stratified", "per_stream"] = "stratified",
         batch_size: int = 1,
         num_workers: int = 0,
         profile_batches: Union[bool, int] = False,
@@ -626,10 +627,15 @@ def __iter__(self) -> Any:
                 self._num_samples_yielded_streaming += self.batch_size
                 yield batch
         else:
+            # Assume this is a CombinedStreamingDataset.
             self.dataset._set_use_streaming_dataloader(True)
             assert self.batch_size
             # TODO: Inject a custom collate function to avoid collating the __NUM_SAMPLES_YIELDED__ key
             for batch in super().__iter__():
+                # Force selection of a new dataset on batch boundaries.
+                # Note: samples may still come from several datasets within a batch,
+                # depending on the `CombinedStreamingDataset`'s `batching_method` value.
+                self.dataset._set_new_dataset_index()
                 self._latest_worker_idx = next(self._worker_idx_iter)  # type: ignore
                 if isinstance(batch, dict) and __NUM_SAMPLES_YIELDED_KEY__ in batch:
                     self._num_samples_yielded_combined[self._latest_worker_idx] = [
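Assuming the new batching_method parameter is forwarded from the dataloader to the combined dataset (the wiring sits outside the hunks shown here), usage would presumably look like this:

```python
from litdata import StreamingDataLoader

# `combined` is the CombinedStreamingDataset from the earlier sketch.
loader = StreamingDataLoader(combined, batch_size=8, batching_method="per_stream")
for batch in loader:
    ...  # with "per_stream", each batch is drawn from a single dataset
```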
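One detail worth flagging: __iter__ now calls self.dataset._set_new_dataset_index(), but in combined.py above that method is defined on _CombinedDatasetIterator, not on CombinedStreamingDataset. For the call to resolve, the dataset presumably needs a thin forwarding method along these lines (a sketch, not part of this commit):

```python
# Hypothetical method on CombinedStreamingDataset; `self._iterator` is the
# _CombinedDatasetIterator returned by __iter__ (see combined.py above).
def _set_new_dataset_index(self) -> None:
    if self._iterator is not None:
        self._iterator._set_new_dataset_index()
```

Note also that with num_workers > 0 each worker process holds its own copy of the dataset, so a call made in the main process may not reach the workers' iterators; the per-batch reset is most clearly effective with in-process iteration.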
