Skip to content

Commit

Permalink
Add support for iterate_over_all for the CombinedDataset (#122)
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton authored May 7, 2024
1 parent bc0366d commit 015f21c
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 35 deletions.
77 changes: 71 additions & 6 deletions src/litdata/streaming/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# limitations under the License.

import random
from copy import deepcopy
from typing import Any, Dict, Iterator, List, Optional, Sequence

import numpy as np
from torch.utils.data import IterableDataset

from litdata.streaming.dataset import StreamingDataset
Expand All @@ -36,15 +38,38 @@ class CombinedStreamingDataset(IterableDataset):
"""

def __init__(
self, datasets: List[StreamingDataset], seed: int = 42, weights: Optional[Sequence[float]] = None
self,
datasets: List[StreamingDataset],
seed: int = 42,
weights: Optional[Sequence[float]] = None,
iterate_over_all: bool = True,
) -> None:
""" "
Arguments:
datasets: The list of the StreamingDataset to use.
seed: The random seed to initialize the sampler
weights: The sampling ratio for the datasets
iterate_over_all: When iterate_over_all is True, the combined dataset iterates over all the datasets.
Otherwise, it stops as soon as one raises a StopIteration.
"""

self._check_datasets(datasets)

self._seed = seed
self._datasets = datasets
self._weights = weights
self._iterate_over_all = iterate_over_all

num_datasets = len(datasets)

if iterate_over_all and weights:
raise ValueError(
"When `iterate_over_all` is set to True, the weights argument shouldn't be provided.",
" Instead, it will be computed from the inverse of the dataset length.",
)

self._iterate_over_all = iterate_over_all

if weights is None:
# Inversely weighted based on length
self._weights = [1 / float(num_datasets)] * num_datasets
Expand All @@ -56,6 +81,15 @@ def __init__(
self._num_samples_yielded: Optional[List[int]] = None
self._current_epoch = 0

def __len__(self) -> Optional[int]:
if self._iterate_over_all:
return self._get_total_length()
return None

# total length of the datasets
def _get_total_length(self) -> int:
return sum(len(d) for d in self._datasets)

def set_epoch(self, current_epoch: int) -> None:
"""Set the current epoch to the datasets on epoch starts.
Expand Down Expand Up @@ -95,6 +129,7 @@ def __iter__(self) -> Iterator[Any]:
self._weights,
self._use_streaming_dataloader,
num_samples_yielded,
self._iterate_over_all,
)
return self._iterator

Expand Down Expand Up @@ -132,31 +167,61 @@ def __init__(
seed: int,
weights: Sequence[float],
use_streaming_dataloader: bool,
num_samples_yielded: Optional[Any] = None,
num_samples_yielded: Any,
iterate_over_all: bool = False,
) -> None:
self._datasets = datasets
self._dataset_iters = [iter(dataset) for dataset in datasets]
self._dataset_indexes = list(range(len(datasets)))
self._num_samples_yielded = [0 for _ in range(len(datasets))]
self._weights = weights
self._num_samples_yielded = num_samples_yielded or [0 for _ in range(len(datasets))]
self._original_weights = deepcopy(weights)
self._weights = deepcopy(weights)
self._rng = random.Random(seed)
self._iterate_over_all = iterate_over_all
self._is_done = False

if num_samples_yielded is not None:
self._num_samples_yielded = num_samples_yielded
for _ in range(sum(num_samples_yielded)):
self._rng.choices(self._dataset_indexes, weights=self._weights, k=1)

self._use_streaming_dataloader = use_streaming_dataloader
self._is_done = False

def __next__(self) -> Any:
if self._iterate_over_all:
while True:
try:
if len(self._dataset_indexes) > 1:
dataset_index = self._get_dataset_index()
elif len(self._dataset_indexes) == 1:
dataset_index = self._dataset_indexes[0]
return self._get_sample(dataset_index)
except StopIteration as e:
if len(self._dataset_indexes) == 1:
self._dataset_indexes = list(range(len(self._datasets)))
self._weights = deepcopy(self._original_weights)
raise e

self._dataset_indexes.pop(dataset_index)
self._weights.pop(dataset_index)
self._weights /= np.sum(self._weights)

# stop on the first iteration
return self._get_sample(self._get_dataset_index())

def _get_dataset_index(self) -> int:
# randomly select a dataset index
(dataset_index,) = self._rng.choices(self._dataset_indexes, weights=self._weights, k=1)
return dataset_index

def _get_sample(self, dataset_index: int) -> Any:
# get the sample
sample = next(self._dataset_iters[dataset_index])

# keep track the sample was fetched
self._num_samples_yielded[dataset_index] += 1

sample = next(self._dataset_iters[dataset_index])

# return a new sample
if self._use_streaming_dataloader:
return {
Expand Down
84 changes: 59 additions & 25 deletions tests/streaming/test_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,25 @@ def _check_datasets(self, datasets) -> None:


def test_combined_dataset_num_samples_yield():
dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 42, weights=(0.5, 0.5))
dataset = TestCombinedStreamingDataset(
[range(10), range(0, -10, -1)], 42, weights=(0.5, 0.5), iterate_over_all=False
)
dataset_iter = iter(dataset)

data = list(dataset_iter)
assert data == [0, 0, 1, 2, -1, -2, -3, 3, 4, 5, 6, -4, 7, 8, -5, -6, 9, -7, -8]

dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 37, weights=(0.5, 0.5))
dataset = TestCombinedStreamingDataset(
[range(10), range(0, -10, -1)], 37, weights=(0.5, 0.5), iterate_over_all=False
)
dataset_iter = iter(dataset)

data = list(dataset_iter)
assert data == [0, 0, -1, -2, -3, -4, -5, 1, -6, 2, -7, -8, 3, 4, -9, 5]

dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 23, weights=(0.5, 0.5))
dataset = TestCombinedStreamingDataset(
[range(10), range(0, -10, -1)], 23, weights=(0.5, 0.5), iterate_over_all=False
)
dataset_iter = iter(dataset)

data = [next(dataset_iter) for _ in range(5)]
Expand All @@ -40,6 +46,13 @@ def test_combined_dataset_num_samples_yield():
assert dataset._iterator._num_samples_yielded == [2, 4]


def test_combined_dataset_num_samples_yield_iterate_over_all():
dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 42, iterate_over_all=True)
assert len(dataset) == 20
samples = list(dataset)
assert len(samples) == 20


class TestStatefulDataset:
def __init__(self, size, step):
self.size = size
Expand Down Expand Up @@ -69,14 +82,20 @@ def load_state_dict(self, state_dict):

def test_combined_dataset_state_dict():
dataset = TestCombinedStreamingDataset(
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)],
42,
weights=(0.5, 0.5),
iterate_over_all=False,
)
assert dataset.state_dict(0, 1) == {}
dataset_iter = iter(dataset)
assert dataset.state_dict(0, 1) == {"0": {"counter": 0}, "1": {"counter": 0}}

dataset2 = TestCombinedStreamingDataset(
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)],
42,
weights=(0.5, 0.5),
iterate_over_all=False,
)
assert dataset2.state_dict(0, 1) == {}

Expand Down Expand Up @@ -111,7 +130,10 @@ def test_combined_dataset_state_dict():
]

dataset2 = TestCombinedStreamingDataset(
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)],
42,
weights=(0.5, 0.5),
iterate_over_all=False,
)
assert dataset2.state_dict(0, 1) == {}
dataset2_iter = iter(dataset2)
Expand All @@ -136,7 +158,7 @@ def test_combined_dataset_state_dict():
],
)
def test_combined_dataset_normalizes_weights(weights, expected):
combined_dataset = TestCombinedStreamingDataset([[1], [2, 3]], weights=weights, seed=1)
combined_dataset = TestCombinedStreamingDataset([[1], [2, 3]], weights=weights, iterate_over_all=False, seed=1)
assert combined_dataset._weights == expected


Expand All @@ -159,21 +181,27 @@ def set_epoch(self, current_epoch):
def test_combined_dataset():
dataset1 = SimpleDataset(0, 10)
dataset2 = SimpleDataset(10, 20)
dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[1.0, 0.0], seed=12345)
dataset = TestCombinedStreamingDataset(
datasets=[dataset1, dataset2], weights=[1.0, 0.0], iterate_over_all=False, seed=12345
)

res = list(dataset)
assert res == list(range(0, 10))

dataset1 = SimpleDataset(0, 10)
dataset2 = SimpleDataset(10, 20)
dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.0, 1.0], seed=12345)
dataset = TestCombinedStreamingDataset(
datasets=[dataset1, dataset2], weights=[0.0, 1.0], iterate_over_all=False, seed=12345
)

res = list(dataset)
assert res == list(range(10, 20))

dataset1 = SimpleDataset(0, 10)
dataset2 = SimpleDataset(10, 20)
dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
dataset = TestCombinedStreamingDataset(
datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345
)

res = list(dataset)
assert 9 in res or 19 in res
Expand All @@ -183,7 +211,9 @@ def test_combined_dataset():

dataset1 = SimpleDataset(0, 10)
dataset2 = SimpleDataset(10, 20)
dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
dataset = TestCombinedStreamingDataset(
datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345
)
dataloader = DataLoader(dataset, batch_size=2, num_workers=1)
dataloader_iter = iter(dataloader)
assert torch.equal(next(dataloader_iter), torch.Tensor([0, 1]))
Expand All @@ -193,7 +223,9 @@ def test_combined_dataset():
def test_combined_dataset_with_dataloader_and_one_worker(batch_size):
dataset1 = SimpleDataset(0, 10)
dataset2 = SimpleDataset(10, 20)
dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
dataset = TestCombinedStreamingDataset(
datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345
)
dataloader = StreamingDataLoader(dataset, num_workers=1, batch_size=batch_size, prefetch_factor=1)
dataloader_iter = iter(dataloader)

Expand Down Expand Up @@ -260,7 +292,9 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):

dataset1 = StreamingDataset(input_dir=Dir(cache_dir_1, data_dir_1), shuffle=True)
dataset2 = StreamingDataset(input_dir=Dir(cache_dir_2, data_dir_2), shuffle=True)
dataset = CombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
dataset = CombinedStreamingDataset(
datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345
)
dataloader = StreamingDataLoader(dataset, num_workers=3, batch_size=2)

assert dataset1.current_epoch == 1
Expand Down Expand Up @@ -454,7 +488,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
{
"dataset": {
"0": {
"num_samples_yielded": 9,
"num_samples_yielded": 8,
"num_workers": 3,
"batch_size": 2,
"current_epoch": 1,
Expand Down Expand Up @@ -482,12 +516,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
},
"current_epoch": 0,
"latest_worker_idx": 2,
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]},
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 1]},
},
{
"dataset": {
"0": {
"num_samples_yielded": 11,
"num_samples_yielded": 9,
"num_workers": 3,
"batch_size": 2,
"current_epoch": 1,
Expand Down Expand Up @@ -515,12 +549,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
},
"current_epoch": 0,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [5, 1], 1: [3, 1], 2: [3, 1]},
"num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [2, 1]},
},
{
"dataset": {
"0": {
"num_samples_yielded": 13,
"num_samples_yielded": 10,
"num_workers": 3,
"batch_size": 2,
"current_epoch": 1,
Expand Down Expand Up @@ -548,7 +582,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
},
"current_epoch": 0,
"latest_worker_idx": 1,
"num_samples_yielded": {0: [5, 1], 1: [5, 1], 2: [3, 1]},
"num_samples_yielded": {0: [4, 1], 1: [4, 1], 2: [2, 1]},
},
]

Expand Down Expand Up @@ -721,7 +755,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
{
"dataset": {
"0": {
"num_samples_yielded": 9,
"num_samples_yielded": 8,
"num_workers": 3,
"batch_size": 2,
"current_epoch": 2,
Expand Down Expand Up @@ -749,12 +783,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
},
"current_epoch": 1,
"latest_worker_idx": 2,
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]},
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 1]},
},
{
"dataset": {
"0": {
"num_samples_yielded": 11,
"num_samples_yielded": 9,
"num_workers": 3,
"batch_size": 2,
"current_epoch": 2,
Expand Down Expand Up @@ -782,12 +816,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
},
"current_epoch": 1,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [5, 1], 1: [3, 1], 2: [3, 1]},
"num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [2, 1]},
},
{
"dataset": {
"0": {
"num_samples_yielded": 13,
"num_samples_yielded": 10,
"num_workers": 3,
"batch_size": 2,
"current_epoch": 2,
Expand Down Expand Up @@ -815,7 +849,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
},
"current_epoch": 1,
"latest_worker_idx": 1,
"num_samples_yielded": {0: [5, 1], 1: [5, 1], 2: [3, 1]},
"num_samples_yielded": {0: [4, 1], 1: [4, 1], 2: [2, 1]},
},
]

Expand Down
Loading

0 comments on commit 015f21c

Please sign in to comment.