Commit

update
awaelchli committed Jul 16, 2024
1 parent 0623680 commit 265c4e9
Showing 2 changed files with 17 additions and 8 deletions.
9 changes: 4 additions & 5 deletions src/litdata/streaming/dataset.py
@@ -246,7 +246,6 @@ def __iter__(self) -> "StreamingDataset":
self.global_index = 0
self.index = 0

self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals)
self.has_triggered_download = False
self.last_time = time()

@@ -266,9 +265,6 @@ def _resume(self, workers_chunks: List[int], workers_intervals: List[Any]) -> None:
num_samples_yielded = self._state_dict["num_samples_yielded"]

# replay sampling from each worker / chunks using the batch size
# workers_chunks, workers_intervals = _associate_chunks_to_workers(
# self.worker_env, chunks_replica, intervals_replica
# )
indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers)
# TODO: Change _replay_chunks_sampling to accept a list
chunks_index, indexes = _replay_chunks_sampling(
@@ -315,7 +311,10 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:
def __next__(self) -> Any:
# Prevent creating more batches on a given process
# print(torch.distributed.get_rank(), self.global_index, len(self), self.stop_length)
if self.global_index >= self.stop_length:
stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals)
print(f"{self.global_index=}, {stop_length=}")
# TODO: This is stopping too early, length is not correct
if self.global_index >= stop_length:
self.current_epoch += 1
raise StopIteration

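For reference, a minimal sketch of the stop-length calculation moved into __next__ above. The interval layout (start offset at index 1, end offset at index 2) is inferred from the interval[2] - interval[1] expression in the diff, and the values are made up:

    # Illustrative values only; each interval is assumed to describe the sample
    # range a worker reads from one chunk.
    worker_intervals = [(0, 0, 190), (1, 190, 380), (2, 380, 500)]
    stop_length = sum(interval[2] - interval[1] for interval in worker_intervals)
    assert stop_length == 500  # 190 + 190 + 120 samples for this worker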
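The _replay_sampling call in the _resume hunk above redistributes num_samples_yielded back onto the workers. A hedged re-implementation of that idea, assuming the DataLoader drains workers one full batch at a time in round-robin order (the actual helper in litdata may differ):

    def replay_sampling(num_samples_yielded: int, batch_size: int, num_workers: int) -> dict:
        # Credit each worker with the samples it would already have produced,
        # walking batch by batch through the round-robin order.
        counts = {worker: 0 for worker in range(num_workers)}
        worker, remaining = 0, num_samples_yielded
        while remaining > 0:
            take = min(batch_size, remaining)  # the last batch may be partial
            counts[worker] += take
            remaining -= take
            worker = (worker + 1) % num_workers
        return counts

    # Example: 10 samples yielded, batch_size=2, 3 workers -> {0: 4, 1: 4, 2: 2}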
16 changes: 13 additions & 3 deletions tests/streaming/test_dataset.py
@@ -811,7 +811,10 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False):
@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
@mock.patch.dict(os.environ, {}, clear=True)
@pytest.mark.timeout(60)
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("shuffle", [
# True,
False,
])
def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
"""This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have
the same size."""
@@ -822,7 +825,7 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):

optimize(
fn=_simple_preprocess,
inputs=list(range(5)),
inputs=list(range(8)),
output_dir=str(tmpdir / "optimized"),
chunk_size=190,
num_workers=4,
@@ -834,16 +837,23 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
batches_to_fetch = 16
batch_to_resume_from = None
dataloader_state = None
# assert len(train_dataloader) == 5 # TODO: This length is wrong
# 8 * 100 tokens = 800 tokens
# 800 / 10 = 80 blocks
# batch size 2: 80 / 2 = 40 batches
# assert len(train_dataloader.dataset) == 80
# assert len(train_dataloader) == 40

for i, batch in enumerate(train_dataloader):
if i == batches_to_fetch:
dataloader_state = train_dataloader.state_dict()
print("saved")
if i == batches_to_fetch + 1:
batch_to_resume_from = batch
break

shutil.rmtree(s3_cache_dir)
os.mkdir(s3_cache_dir)
print("resume")
train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir, shuffle=shuffle)
assert dataloader_state is not None
assert batch_to_resume_from is not None
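A small sketch of the expected-length arithmetic spelled out in the commented-out assertions above, assuming _simple_preprocess emits 100 tokens per input and the dataset serves 10-token blocks in batches of 2:

    num_inputs, tokens_per_input = 8, 100  # inputs=list(range(8))
    block_size, batch_size = 10, 2         # assumed from the comments
    total_tokens = num_inputs * tokens_per_input  # 800 tokens
    num_blocks = total_tokens // block_size       # 80 blocks  -> len(dataset)
    num_batches = num_blocks // batch_size        # 40 batches -> len(dataloader)
    assert (total_tokens, num_blocks, num_batches) == (800, 80, 40)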

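The remainder of the test is cut off above; a hedged sketch of how the resume check typically continues, assuming StreamingDataLoader exposes load_state_dict as the counterpart of the state_dict call captured earlier and that batches are plain tensors:

    import torch

    train_dataloader.load_state_dict(dataloader_state)
    for batch in train_dataloader:
        # The first batch after restoring the state should match the batch that
        # was captured right after state_dict() was taken.
        assert torch.equal(batch, batch_to_resume_from)
        break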