Fix resuming dataset state #217

Merged: 17 commits, Jul 11, 2024
7 changes: 5 additions & 2 deletions src/litdata/streaming/dataset.py
@@ -331,7 +331,7 @@ def __next__(self) -> Any:
    index=index,
    chunk_index=self.worker_chunks[self.chunk_index - 1],
    # We provide the chunk indexes only on the first index
-   chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
+   chunk_indexes=None if self.has_triggered_download else self.worker_chunks[self.chunk_index - 1 :],
    is_last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
    )
)
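
In plain terms: when a run is resumed, `self.chunk_index` has already been restored from the checkpoint, so passing the full `self.worker_chunks` list on the first `__next__` would schedule chunks that were consumed before the save. A minimal sketch with hypothetical values (the names mirror the attributes in the hunk above):

# Hypothetical state restored for one worker after resuming.
worker_chunks = [3, 7, 11, 15]  # chunk ids assigned to this worker
chunk_index = 3                 # the chunk being read is worker_chunks[2]

# Old behaviour: every assigned chunk is scheduled for download again,
# including the ones consumed before the checkpoint.
assert worker_chunks == [3, 7, 11, 15]

# Fixed behaviour: only the current chunk and the remaining ones are scheduled.
assert worker_chunks[chunk_index - 1:] == [11, 15]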
@@ -520,9 +520,12 @@ def _replay_chunks_sampling(

    for worker_idx, intervals in workers_intervals.items():
        for interval in intervals:
-           size = interval[-1] - interval[0]
+           size = interval[2] - interval[1]
            if indexes[worker_idx] >= size:
                indexes[worker_idx] -= size
                chunks_index[worker_idx] += 1
+           else:
+               # We've reached the chunk where resuming needs to take place (for this worker)
+               break

    return chunks_index, indexes
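
Read together with the 4-element intervals exercised in the tests below, the patched helper can be summarized as this self-contained sketch (reconstructed from the hunk; naming the tuple elements chunk_start/roi_start/roi_end/chunk_end is an assumption based on the `interval[2] - interval[1]` size):

from typing import Dict, List, Tuple

Interval = Tuple[int, int, int, int]  # assumed: (chunk_start, roi_start, roi_end, chunk_end)


def replay_chunks_sampling(
    workers_intervals: Dict[int, List[Interval]], indexes: Dict[int, int]
) -> Tuple[Dict[int, int], Dict[int, int]]:
    # Convert each worker's count of already-consumed samples into
    # (chunk offset, index within that chunk).
    chunks_index = {worker_idx: 0 for worker_idx in workers_intervals}
    for worker_idx, intervals in workers_intervals.items():
        for interval in intervals:
            # Only the region of interest (elements 1 and 2) is served from a
            # chunk, so it, not the full chunk span, determines the chunk's size.
            size = interval[2] - interval[1]
            if indexes[worker_idx] >= size:
                indexes[worker_idx] -= size
                chunks_index[worker_idx] += 1
            else:
                break  # resuming lands inside this chunk
    return chunks_index, indexes


# Two chunks of 10 samples each: 15 consumed samples resume at chunk 1, sample 5.
assert replay_chunks_sampling({0: [(0, 0, 10, 10), (10, 10, 20, 20)]}, {0: 15}) == ({0: 1}, {0: 5})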
68 changes: 64 additions & 4 deletions tests/streaming/test_dataset.py
@@ -14,14 +14,15 @@
import json
import os
import random
+import shutil
import sys
from time import sleep
from unittest import mock

import numpy as np
import pytest
import torch
-from litdata import train_test_split
+from litdata import optimize, train_test_split
from litdata.constants import _ZSTD_AVAILABLE
from litdata.processing import functions
from litdata.streaming import Cache
@@ -793,6 +794,57 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir):
    assert not torch.equal(batch_1, batch_2)


def _simple_preprocess(_):
    for _ in range(10):
        yield torch.randint(0, 100, size=(10,), dtype=torch.int64)


def _get_simulated_s3_dataloader(cache_dir, data_dir):
    dataset = EmulateS3StreamingDataset(
        input_dir=Dir(cache_dir, data_dir),
        item_loader=TokensLoader(block_size=10),
    )
    return StreamingDataLoader(dataset, batch_size=2, num_workers=1)


@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on Windows")
@mock.patch.dict(os.environ, {}, clear=True)
def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch):
    """Test resuming from a chunk beyond the first one when the subsequent chunks have different sizes."""
    s3_cache_dir = str(tmpdir / "s3cache")
    optimize_cache_dir = str(tmpdir / "optimize_cache")
    data_dir = str(tmpdir / "optimized")
    monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", optimize_cache_dir)

    optimize(
        fn=_simple_preprocess,
        inputs=list(range(8)),
        output_dir=data_dir,
        chunk_size=190,
        num_workers=4,
    )
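    # Each input yields 100 tokens and chunk_size=190 is not a multiple of that,
    # so the optimized dataset should end up with chunks of unequal size -- the
    # scenario the docstring above calls for.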
    assert len(os.listdir(data_dir)) > 0

    os.mkdir(s3_cache_dir)
    train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)
    batches_to_fetch = 16
    batch_to_resume_from = None
    for i, batch in enumerate(train_dataloader):
        if i == batches_to_fetch:
            dataloader_state = train_dataloader.state_dict()
        if i == batches_to_fetch + 1:
            batch_to_resume_from = batch
            break

    shutil.rmtree(s3_cache_dir)
    os.mkdir(s3_cache_dir)
    train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)
    train_dataloader.load_state_dict(dataloader_state)
    # The first batch after resuming must match the batch that followed the saved state in the original run.
    assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from)


@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on Windows")
def test_dataset_valid_state(tmpdir, monkeypatch):
    seed_everything(42)
@@ -915,19 +967,27 @@ def test_replay_sampling():

def test_replay_chunks_sampling():
    chunks_replica = range(10)
-   intervals_replica = [(i, i + 5) for i in range(0, 50, 5)]
+   intervals_replica = [(i, i, i + 5, i + 5) for i in range(0, 50, 5)]
    workers_chunks, workers_intervals = _associate_chunks_to_workers(
        _WorkerEnv(2, 0), chunks_replica, intervals_replica
    )
    assert workers_chunks == {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7, 9]}
    assert workers_intervals == {
-       0: [(0, 5), (10, 15), (20, 25), (30, 35), (40, 45)],
-       1: [(5, 10), (15, 20), (25, 30), (35, 40), (45, 50)],
+       0: [(0, 0, 5, 5), (10, 10, 15, 15), (20, 20, 25, 25), (30, 30, 35, 35), (40, 40, 45, 45)],
+       1: [(5, 5, 10, 10), (15, 15, 20, 20), (25, 25, 30, 30), (35, 35, 40, 40), (45, 45, 50, 50)],
    }
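    # Each interval serves interval[2] - interval[1] = 5 samples, so e.g. 16
    # consumed samples for worker 0 replay as 3 full chunks plus 1 leftover:
    # chunk offset 3, index 1 within that chunk.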
    assert _replay_chunks_sampling(workers_intervals, {0: 16, 1: 11}) == ({0: 3, 1: 2}, {0: 1, 1: 1})
    assert _replay_chunks_sampling(workers_intervals, {0: 14, 1: 13}) == ({0: 2, 1: 2}, {0: 4, 1: 3})
    assert _replay_chunks_sampling(workers_intervals, {0: 15, 1: 12}) == ({0: 3, 1: 2}, {0: 0, 1: 2})

    # Test that replay stops at the right chunk
    workers_intervals = {0: [(0, 0, 10, 10), (10, 10, 20, 20), (20, 20, 21, 21), (21, 21, 30, 30)]}
    indexes = {0: 15}
    # Replay should stop at chunk index 1, because 15 - 10 = 5 falls within that chunk
    chunk_indexes, indexes = _replay_chunks_sampling(workers_intervals, indexes)
    assert chunk_indexes == {0: 1}
    assert indexes == {0: 5}


@pytest.mark.parametrize(
"compression",