Fix: Chunks deletion issue #375

Merged (11 commits) on Sep 26, 2024
26 changes: 20 additions & 6 deletions src/litdata/processing/data_processor.py
@@ -456,6 +456,7 @@ def run(self) -> None:
try:
self._setup()
self._loop()
self._terminate()
except Exception:
traceback_format = traceback.format_exc()
self.error_queue.put(traceback_format)
@@ -469,6 +470,19 @@ def _setup(self) -> None:
self._start_uploaders()
self._start_remover()

def _terminate(self) -> None:
"""Make sure all the uploaders, downloaders and removers are terminated."""
for uploader in self.uploaders:
if uploader.is_alive():
uploader.join()

for downloader in self.downloaders:
if downloader.is_alive():
downloader.join()

if self.remover and self.remover.is_alive():
self.remover.join()

def _loop(self) -> None:
num_downloader_finished = 0

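For context, here is a minimal, self-contained sketch of the join-before-exit pattern the new _terminate method above introduces. The Worker, uploader, and remover names below are hypothetical stand-ins for illustration, not litdata's actual classes.

# Hypothetical sketch: join every background helper before the worker exits,
# so no upload or delete operation can race with the final bookkeeping.
import multiprocessing as mp
import time


def _background_task(name: str) -> None:
    time.sleep(0.1)  # simulate uploading or removing chunk files


class Worker:
    def __init__(self) -> None:
        self.uploaders = [mp.Process(target=_background_task, args=("uploader",)) for _ in range(2)]
        self.remover = mp.Process(target=_background_task, args=("remover",))

    def run(self) -> None:
        for p in self.uploaders:
            p.start()
        self.remover.start()
        # ... the main loop would process items and enqueue chunks here ...
        self._terminate()

    def _terminate(self) -> None:
        # Mirror of the change above: wait for anything still alive.
        for p in [*self.uploaders, self.remover]:
            if p.is_alive():
                p.join()


if __name__ == "__main__":
    Worker().run()
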
@@ -795,7 +809,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul

chunks = [file for file in os.listdir(cache_dir) if file.endswith(".bin")]
if chunks and delete_cached_files and output_dir.path is not None:
raise RuntimeError(f"All the chunks should have been deleted. Found {chunks}")
raise RuntimeError(f"All the chunks should have been deleted. Found {chunks} in cache: {cache_dir}")

merge_cache = Cache(cache_dir, chunk_bytes=1)
node_rank = _get_node_rank()
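
To make the failure mode concrete, here is a stripped-down, hypothetical reproduction of the race this PR closes (not litdata code): listing the cache directory can still observe .bin files while a separate remover process is deleting them, unless that process is joined first.

# Hypothetical reproduction sketch: the chunk check is only reliable once
# the remover process has been joined.
import multiprocessing as mp
import os
import tempfile


def _remove_all(cache_dir: str) -> None:
    for name in os.listdir(cache_dir):
        os.remove(os.path.join(cache_dir, name))


if __name__ == "__main__":
    cache_dir = tempfile.mkdtemp()
    for i in range(5):
        open(os.path.join(cache_dir, f"chunk-{i}.bin"), "wb").close()

    remover = mp.Process(target=_remove_all, args=(cache_dir,))
    remover.start()
    remover.join()  # without this join, the check below is racy

    chunks = [f for f in os.listdir(cache_dir) if f.endswith(".bin")]
    assert not chunks, f"All the chunks should have been deleted. Found {chunks} in cache: {cache_dir}"
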
@@ -1111,6 +1125,10 @@ def run(self, data_recipe: DataRecipe) -> None:

current_total = new_total
if current_total == num_items:
# make sure all processes are terminated
for w in self.workers:
if w.is_alive():
w.join()
break

if _IS_IN_STUDIO and node_rank == 0 and _ENABLE_STATUS:
@@ -1119,17 +1137,13 @@ def run(self, data_recipe: DataRecipe) -> None:

# Exit early if all the workers are done.
# This means there was some kind of error.
# TODO: Check whether this is still required.
if all(not w.is_alive() for w in self.workers):
raise RuntimeError("One of the worker has failed")

if _TQDM_AVAILABLE:
pbar.close()

# TODO: Check whether this is still required.
if num_nodes == 1:
for w in self.workers:
w.join()

print("Workers are finished.")
result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)
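
The orchestration side follows the same idea. Below is a hedged, standalone sketch (with hypothetical worker code, not litdata's) of counting completed items and joining every worker before any final bookkeeping runs.

# Hypothetical sketch of the run() change: once the expected number of items
# has been counted, join all workers before calling any _done-style check.
import multiprocessing as mp


def _work(progress_queue) -> None:
    progress_queue.put(1)  # report one completed item


if __name__ == "__main__":
    progress_queue = mp.Queue()
    workers = [mp.Process(target=_work, args=(progress_queue,)) for _ in range(4)]
    for w in workers:
        w.start()

    num_items, current_total = 4, 0
    while current_total < num_items:
        current_total += progress_queue.get()

    # make sure all processes are terminated before final bookkeeping
    for w in workers:
        if w.is_alive():
            w.join()
    print("Workers are finished.")
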

55 changes: 55 additions & 0 deletions tests/processing/test_functions.py
@@ -1,10 +1,15 @@
import glob
import os
import random
import shutil
import sys
from pathlib import Path
from unittest import mock

import cryptography
import numpy as np
import pytest
import requests
from litdata import StreamingDataset, merge_datasets, optimize, walk
from litdata.processing.functions import _get_input_dir, _resolve_dir
from litdata.streaming.cache import Cache
@@ -475,3 +480,53 @@ def test_optimize_with_rsa_encryption(tmpdir):
# encryption=rsa,
# mode="overwrite",
# )


def tokenize(filename: str):
with open(filename, encoding="utf-8") as file:
text = file.read()
text = text.strip().split(" ")
word_to_int = {word: random.randint(1, 1000) for word in set(text)} # noqa: S311
tokenized = [word_to_int[word] for word in text]

yield tokenized


@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows")
def test_optimize_race_condition(tmpdir):
# issue: https://github.com/Lightning-AI/litdata/issues/367
# run_commands = [
# "mkdir -p tempdir/custom_texts",
# "curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output tempdir/custom_texts/book1.txt",
# "curl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output tempdir/custom_texts/book2.txt",
# ]
shutil.rmtree(f"{tmpdir}/custom_texts", ignore_errors=True)
os.makedirs(f"{tmpdir}/custom_texts", exist_ok=True)

urls = [
"https://www.gutenberg.org/cache/epub/24440/pg24440.txt",
"https://www.gutenberg.org/cache/epub/26393/pg26393.txt",
]

for i, url in enumerate(urls):
print(f"downloading {i+1} file")
with requests.get(url, stream=True, timeout=10) as r:
r.raise_for_status() # Raise an exception for bad status codes

with open(f"{tmpdir}/custom_texts/book{i+1}.txt", "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)

print("=" * 100)

train_files = sorted(glob.glob(str(Path(f"{tmpdir}/custom_texts") / "*.txt")))
print("=" * 100)
print(train_files)
print("=" * 100)
optimize(
fn=tokenize,
inputs=train_files,
output_dir=f"{tmpdir}/temp",
num_workers=1,
chunk_bytes="50MB",
)
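
As a possible follow-up (not part of this PR), the optimized output could be read back to confirm it is usable. This assumes StreamingDataset can stream directly from the local output directory written above.

# Hedged sanity-check sketch: read the optimized dataset back from disk.
# Assumption: StreamingDataset accepts the local output directory produced by optimize().
from litdata import StreamingDataset


def check_optimized_output(output_dir: str) -> None:
    ds = StreamingDataset(input_dir=output_dir)
    assert len(ds) > 0  # at least one tokenized sample was written
    assert ds[0] is not None


# e.g. check_optimized_output(f"{tmpdir}/temp") at the end of the test above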