diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index f1af9afa..20427e23 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -456,6 +456,7 @@ def run(self) -> None: try: self._setup() self._loop() + self._terminate() except Exception: traceback_format = traceback.format_exc() self.error_queue.put(traceback_format) @@ -469,6 +470,19 @@ def _setup(self) -> None: self._start_uploaders() self._start_remover() + def _terminate(self) -> None: + """Make sure all the uploaders, downloaders and removers are terminated.""" + for uploader in self.uploaders: + if uploader.is_alive(): + uploader.join() + + for downloader in self.downloaders: + if downloader.is_alive(): + downloader.join() + + if self.remover and self.remover.is_alive(): + self.remover.join() + def _loop(self) -> None: num_downloader_finished = 0 @@ -795,7 +809,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul chunks = [file for file in os.listdir(cache_dir) if file.endswith(".bin")] if chunks and delete_cached_files and output_dir.path is not None: - raise RuntimeError(f"All the chunks should have been deleted. Found {chunks}") + raise RuntimeError(f"All the chunks should have been deleted. Found {chunks} in cache: {cache_dir}") merge_cache = Cache(cache_dir, chunk_bytes=1) node_rank = _get_node_rank() @@ -1110,6 +1124,10 @@ def run(self, data_recipe: DataRecipe) -> None: current_total = new_total if current_total == num_items: + # make sure all processes are terminated + for w in self.workers: + if w.is_alive(): + w.join() break if _IS_IN_STUDIO and node_rank == 0 and _ENABLE_STATUS: @@ -1118,17 +1136,13 @@ def run(self, data_recipe: DataRecipe) -> None: # Exit early if all the workers are done. # This means there were some kinda of errors. + # TODO: Check whether this is still required. if all(not w.is_alive() for w in self.workers): raise RuntimeError("One of the worker has failed") if _TQDM_AVAILABLE: pbar.close() - # TODO: Check whether this is still required. - if num_nodes == 1: - for w in self.workers: - w.join() - print("Workers are finished.") result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir) diff --git a/tests/processing/test_functions.py b/tests/processing/test_functions.py index 80eec0ba..42482b7e 100644 --- a/tests/processing/test_functions.py +++ b/tests/processing/test_functions.py @@ -1,10 +1,15 @@ +import glob import os +import random +import shutil import sys +from pathlib import Path from unittest import mock import cryptography import numpy as np import pytest +import requests from litdata import StreamingDataset, merge_datasets, optimize, walk from litdata.processing.functions import _get_input_dir, _resolve_dir from litdata.streaming.cache import Cache @@ -475,3 +480,53 @@ def test_optimize_with_rsa_encryption(tmpdir): # encryption=rsa, # mode="overwrite", # ) + + +def tokenize(filename: str): + with open(filename, encoding="utf-8") as file: + text = file.read() + text = text.strip().split(" ") + word_to_int = {word: random.randint(1, 1000) for word in set(text)} # noqa: S311 + tokenized = [word_to_int[word] for word in text] + + yield tokenized + + +@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows") +def test_optimize_race_condition(tmpdir): + # issue: https://github.com/Lightning-AI/litdata/issues/367 + # run_commands = [ + # "mkdir -p tempdir/custom_texts", + # "curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output tempdir/custom_texts/book1.txt", + # "curl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output tempdir/custom_texts/book2.txt", + # ] + shutil.rmtree(f"{tmpdir}/custom_texts", ignore_errors=True) + os.makedirs(f"{tmpdir}/custom_texts", exist_ok=True) + + urls = [ + "https://www.gutenberg.org/cache/epub/24440/pg24440.txt", + "https://www.gutenberg.org/cache/epub/26393/pg26393.txt", + ] + + for i, url in enumerate(urls): + print(f"downloading {i+1} file") + with requests.get(url, stream=True, timeout=10) as r: + r.raise_for_status() # Raise an exception for bad status codes + + with open(f"{tmpdir}/custom_texts/book{i+1}.txt", "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + + print("=" * 100) + + train_files = sorted(glob.glob(str(Path(f"{tmpdir}/custom_texts") / "*.txt"))) + print("=" * 100) + print(train_files) + print("=" * 100) + optimize( + fn=tokenize, + inputs=train_files, + output_dir=f"{tmpdir}/temp", + num_workers=1, + chunk_bytes="50MB", + )