diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 30ef688e..d44662ae 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -21,6 +21,7 @@ import tempfile import traceback from abc import abstractmethod +from contextlib import suppress from dataclasses import dataclass from multiprocessing import Process, Queue from pathlib import Path @@ -42,7 +43,7 @@ _TQDM_AVAILABLE, ) from litdata.processing.readers import BaseReader, StreamingDataLoaderReader -from litdata.processing.utilities import _create_dataset +from litdata.processing.utilities import _create_dataset, download_directory_from_S3, remove_uuid_from_filename from litdata.streaming import Cache from litdata.streaming.cache import Dir from litdata.streaming.client import S3Client @@ -58,8 +59,8 @@ if _LIGHTNING_CLOUD_AVAILABLE: from lightning_cloud.openapi import V1DatasetType - if _BOTO3_AVAILABLE: + import boto3 import botocore logger = logging.Logger(__name__) @@ -229,10 +230,16 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_ if obj.scheme == "s3": try: + output_filepath = str(obj.path).lstrip("/") + + if local_filepath.__contains__(".checkpoints"): + output_filepath = os.path.join(output_filepath, ".checkpoints") if tmpdir is None: - output_filepath = os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath)) + output_filepath = os.path.join(output_filepath, os.path.basename(local_filepath)) else: - output_filepath = os.path.join(str(obj.path).lstrip("/"), local_filepath.replace(tmpdir, "")[1:]) + output_filepath = os.path.join(output_filepath, local_filepath.replace(tmpdir, "")[1:]) + + output_filepath = remove_uuid_from_filename(output_filepath) # remove unique id from checkpoints s3.client.upload_file( local_filepath, @@ -243,10 +250,17 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_ print(e) elif output_dir.path: + output_filepath = output_dir.path + + if local_filepath.__contains__(".checkpoints"): + output_filepath = os.path.join(output_filepath, ".checkpoints") + if tmpdir is None: - output_filepath = os.path.join(output_dir.path, os.path.basename(local_filepath)) + output_filepath = os.path.join(output_filepath, os.path.basename(local_filepath)) else: - output_filepath = os.path.join(output_dir.path, local_filepath.replace(tmpdir, "")[1:]) + output_filepath = os.path.join(output_filepath, local_filepath.replace(tmpdir, "")[1:]) + + output_filepath = remove_uuid_from_filename(output_filepath) # remove unique id from checkpoints os.makedirs(os.path.dirname(output_filepath), exist_ok=True) shutil.move(local_filepath, output_filepath) @@ -388,6 +402,9 @@ def __init__( remove: bool, reader: Optional[BaseReader] = None, writer_starting_chunk_index: int = 0, + use_checkpoint: bool = False, + checkpoint_chunks_info: Optional[List[Dict[str, Any]]] = None, + checkpoint_next_index: Optional[int] = None, ) -> None: """The BaseWorker is responsible to process the user data.""" self.worker_index = worker_index @@ -416,7 +433,10 @@ def __init__( self._counter = 0 self._last_time = time() self._index_counter = 0 - self.writer_starting_chunk_index = writer_starting_chunk_index + self.writer_starting_chunk_index: int = writer_starting_chunk_index + self.use_checkpoint: bool = use_checkpoint + self.checkpoint_chunks_info: Optional[List[Dict[str, Any]]] = checkpoint_chunks_info + self.checkpoint_next_index: Optional[int] = checkpoint_next_index def 
run(self) -> None: try: @@ -424,7 +444,6 @@ def run(self) -> None: self._loop() except Exception: traceback_format = traceback.format_exc() - print(traceback_format) self.error_queue.put(traceback_format) print(f"Worker {str(_get_node_rank() * self.num_workers + self.worker_index)} is done.") @@ -500,6 +519,11 @@ def _create_cache(self) -> None: os.makedirs(self.cache_data_dir, exist_ok=True) self.cache_chunks_dir = _get_cache_dir() + + if os.path.exists(self.cache_chunks_dir): + # clean up the cache chunks dir folder to avoid previous json files from interfering with the current run + shutil.rmtree(self.cache_chunks_dir, ignore_errors=True) + os.makedirs(self.cache_chunks_dir, exist_ok=True) if isinstance(self.data_recipe, DataTransformRecipe): @@ -514,6 +538,19 @@ def _create_cache(self) -> None: ) self.cache._reader._rank = _get_node_rank() * self.num_workers + self.worker_index + # return + if self.use_checkpoint and all( + [ + self.checkpoint_chunks_info is not None, + self.checkpoint_next_index is not None, + ] + ): + assert isinstance(self.checkpoint_next_index, int) + assert isinstance(self.checkpoint_chunks_info, list) + + self.cache._writer._chunks_info = self.checkpoint_chunks_info + self.cache._writer._chunk_index += self.checkpoint_next_index + def _try_upload(self, data: Optional[Union[str, Tuple[str, str]]]) -> None: if not data or (self.output_dir.url if self.output_dir.url else self.output_dir.path) is None: return @@ -636,8 +673,11 @@ def _handle_data_chunk_recipe(self, index: int) -> None: chunk_filepath = self.cache._add_item(self._index_counter, item_data_or_generator) self._try_upload(chunk_filepath) self._index_counter += 1 + if self.use_checkpoint: + checkpoint_filepath = self.cache.save_checkpoint() + self._try_upload(checkpoint_filepath) except Exception as e: - raise RuntimeError(f"Failed processing {self.items[index]}") from e + raise RuntimeError(f"Failed processing {self.items[index]=}; {index=}") from e def _handle_data_chunk_recipe_end(self) -> None: chunks_filepaths = self.cache.done() @@ -647,6 +687,10 @@ def _handle_data_chunk_recipe_end(self) -> None: if isinstance(chunk_filepath, str) and os.path.exists(chunk_filepath): self.to_upload_queues[i % self.num_uploaders].put(chunk_filepath) + if self.use_checkpoint and not self.data_recipe.is_generator: + checkpoint_filepath = self.cache.save_checkpoint() + self._try_upload(checkpoint_filepath) + def _handle_data_transform_recipe(self, index: int) -> None: # Don't use a context manager to avoid deleting files that are being uploaded. output_dir = tempfile.mkdtemp() @@ -847,6 +891,7 @@ def __init__( weights: Optional[List[int]] = None, reader: Optional[BaseReader] = None, state_dict: Optional[Dict[int, int]] = None, + use_checkpoint: bool = False, ): """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make training faster. @@ -866,6 +911,8 @@ def __init__( This is used to evenly split the work among the workers. reader: Map the inputs to worker inputs and provides a read method to read a slice of the data. state_dict: The writer state dict. This is used to decide how to append data to an existing dataset. + use_checkpoint: Whether to create checkpoints while processing the data, which can be used to resume the + processing from the last checkpoint if the process is interrupted. 
(`Default: False`) """ self.input_dir = _resolve_dir(input_dir) @@ -884,6 +931,9 @@ def __init__( self.reorder_files = reorder_files self.weights = weights self.reader = reader + self.use_checkpoint = use_checkpoint + self.checkpoint_chunks_info: Optional[List[List[Dict[str, Any]]]] = None + self.checkpoint_next_index: Optional[List[int]] = None self.state_dict = state_dict or {rank: 0 for rank in range(self.num_workers)} @@ -905,6 +955,10 @@ def run(self, data_recipe: DataRecipe) -> None: if not isinstance(data_recipe, DataRecipe): raise ValueError("The provided value should be a data recipe.") + if not self.use_checkpoint and isinstance(data_recipe, DataChunkRecipe): + # clean up checkpoints if not using checkpoints + self._cleanup_checkpoints() + t0 = time() print(f"Setup started with fast_dev_run={self.fast_dev_run}.") @@ -943,8 +997,28 @@ def run(self, data_recipe: DataRecipe) -> None: print(f"Setup finished in {round(time() - t0, 3)} seconds. Found {len(user_items)} items to process.") + if self.use_checkpoint: + if hasattr(data_recipe, "is_generator") and data_recipe.is_generator: + # The checkpoint feature is not supported for generators for now. + raise ValueError("Checkpoint feature is not supported for generators yet.") + # get the last checkpoint details + print("Resuming from last saved checkpoint...") + self._load_checkpoint_config(workers_user_items) + + assert isinstance(self.checkpoint_next_index, list) + + if all(self.checkpoint_next_index[i] == 0 for i in range(self.num_workers)): + # no progress recorded yet: save the current configuration in .checkpoints/config.json + print("No checkpoints found. Saving current configuration...") + self._save_current_config(workers_user_items) + else: + # load the last checkpoint details + assert isinstance(self.checkpoint_next_index, list) + workers_user_items = [w[self.checkpoint_next_index[i] :] for i, w in enumerate(workers_user_items)] + print("Checkpoints loaded successfully.") + if self.fast_dev_run: - items_to_keep = self.fast_dev_run if type(self.fast_dev_run) is int else _DEFAULT_FAST_DEV_RUN_ITEMS + items_to_keep = self.fast_dev_run if isinstance(self.fast_dev_run, int) else _DEFAULT_FAST_DEV_RUN_ITEMS workers_user_items = [w[:items_to_keep] for w in workers_user_items] print(f"Fast dev run is enabled. Limiting to {items_to_keep} items per process.") @@ -1040,10 +1114,14 @@ def run(self, data_recipe: DataRecipe) -> None: ) print("Finished data processing!") + if self.use_checkpoint and isinstance(data_recipe, DataChunkRecipe): + # clean up checkpoints + self._cleanup_checkpoints() def _exit_on_error(self, error: str) -> None: for w in self.workers: - w.join(0) + w.terminate() # an error has already occurred, so there is no benefit in processing further.
raise RuntimeError(f"We found the following error {error}.") def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: List[List[Any]]) -> None: @@ -1068,6 +1146,9 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L self.delete_cached_files, self.reader, self.state_dict[worker_idx], + self.use_checkpoint, + self.checkpoint_chunks_info[worker_idx] if self.checkpoint_chunks_info else None, + self.checkpoint_next_index[worker_idx] if self.checkpoint_next_index else None, ) worker.start() workers.append(worker) @@ -1100,3 +1181,173 @@ def _cleanup_cache(self) -> None: shutil.rmtree(cache_data_dir, ignore_errors=True) os.makedirs(cache_data_dir, exist_ok=True) + + def _cleanup_checkpoints(self) -> None: + if not isinstance(self.output_dir, Dir): + raise ValueError("The provided output_dir isn't a Dir Object.") + + if self.output_dir.url is None: + # this is a local directory + if self.output_dir.path is None: + return + + if os.path.exists(self.output_dir.path): + # clear the checkpoints + with suppress(FileNotFoundError): + shutil.rmtree(os.path.join(self.output_dir.path, ".checkpoints")) + + return + + obj = parse.urlparse(self.output_dir.url) + + if obj.scheme != "s3": + raise ValueError(f"The provided folder should start with s3://. Found {self.output_dir.path}.") + + s3 = boto3.client("s3") + + prefix = obj.path.lstrip("/").rstrip("/") + "/" + + # Delete all the files (including the index file in overwrite mode) + bucket_name = obj.netloc + s3 = boto3.resource("s3") + + checkpoint_prefix = os.path.join(prefix, ".checkpoints") + + for obj in s3.Bucket(bucket_name).objects.filter(Prefix=checkpoint_prefix): + s3.Object(bucket_name, obj.key).delete() + + def _save_current_config(self, workers_user_items: List[List[Any]]) -> None: + if not self.use_checkpoint: + return + + # save the current configuration in the config.json file + config = { + "num_workers": self.num_workers, + "workers_user_items": workers_user_items, + } + + try: + if self.output_dir.url is None: + assert self.output_dir.path + + if not os.path.exists(os.path.join(self.output_dir.path, ".checkpoints")): + os.makedirs(os.path.join(self.output_dir.path, ".checkpoints")) + + with open(os.path.join(self.output_dir.path, ".checkpoints", "config.json"), "w") as f: + json.dump(config, f) + + return + + obj = parse.urlparse(self.output_dir.url) + + if obj.scheme != "s3": + raise ValueError(f"The provided folder should start with s3://. 
Found {self.output_dir.path}.") + + # TODO: Add support for all cloud providers + + s3 = S3Client() + + prefix = obj.path.lstrip("/").rstrip("/") + "/" + ".checkpoints/" + + # write config.json file to temp directory and upload it to s3 + with tempfile.TemporaryDirectory() as temp_dir: + temp_file_name = os.path.join(temp_dir, "config.json") + with open(temp_file_name, "w") as f: + json.dump(config, f) + s3.client.upload_file( + temp_file_name, + obj.netloc, + os.path.join(prefix, "config.json"), + ) + except Exception as e: + print(e) + + def _load_checkpoint_config(self, workers_user_items: List[List[Any]]) -> None: + if not self.use_checkpoint: + return + + default_chunk_info: List[Dict[str, Any]] = [] + + self.checkpoint_chunks_info = [default_chunk_info for _ in range(self.num_workers)] + self.checkpoint_next_index = [0 for _ in range(self.num_workers)] + + if self.output_dir.url is None: + assert self.output_dir.path + + if not os.path.exists(os.path.join(self.output_dir.path, ".checkpoints")): + return + + if not os.path.exists(os.path.join(self.output_dir.path, ".checkpoints", "config.json")): + # if the config.json file doesn't exist, we don't have any checkpoint saved + return + + with open(os.path.join(self.output_dir.path, ".checkpoints", "config.json")) as f: + config = json.load(f) + + if config["num_workers"] != self.num_workers: + raise ValueError( + "The number of workers in the checkpoints doesn't match the current number of workers." + ) + + if config["workers_user_items"] != workers_user_items: + raise ValueError("Existing checkpoints are not compatible with the current configuration.") + + checkpoint_file_names = [f"checkpoint-{worker_idx}.json" for worker_idx in range(self.num_workers)] + + for i, checkpoint_file_name in enumerate(checkpoint_file_names): + if not os.path.exists(os.path.join(self.output_dir.path, ".checkpoints", checkpoint_file_name)): + # if the checkpoint file doesn't exist, we don't have any checkpoint saved for this worker + continue + + with open(os.path.join(self.output_dir.path, ".checkpoints", checkpoint_file_name)) as f: + checkpoint = json.load(f) + + self.checkpoint_chunks_info[i] = checkpoint["chunks"] + self.checkpoint_next_index[i] = checkpoint["done_till_index"] + return + + obj = parse.urlparse(self.output_dir.url) + + if obj.scheme != "s3": + raise ValueError(f"The provided folder should start with s3://. Found {self.output_dir.path}.") + + # TODO: Add support for all cloud providers + + prefix = obj.path.lstrip("/").rstrip("/") + "/" + ".checkpoints/" + + # Delete all the files (including the index file in overwrite mode) + bucket_name = obj.netloc + + # download all the checkpoint files in tempdir and read them + with tempfile.TemporaryDirectory() as temp_dir: + saved_file_dir = download_directory_from_S3(bucket_name, prefix, temp_dir) + + if not os.path.exists(os.path.join(saved_file_dir, "config.json")): + # if the config.json file doesn't exist, we don't have any checkpoint saved + return + + # read the config.json file + with open(os.path.join(saved_file_dir, "config.json")) as f: + config = json.load(f) + + if config["num_workers"] != self.num_workers: + raise ValueError( + "The number of workers in the checkpoints doesn't match the current number of workers." 
+ ) + + if config["workers_user_items"] != workers_user_items: + raise ValueError("Existing checkpoints are not compatible with the current configuration.") + + checkpoint_file_names = [f"checkpoint-{worker_idx}.json" for worker_idx in range(self.num_workers)] + + for i, checkpoint_file_name in enumerate(checkpoint_file_names): + if not os.path.exists(os.path.join(saved_file_dir, checkpoint_file_name)): + # if the checkpoint file doesn't exist, we don't have any checkpoint saved for this worker + continue + + with open(os.path.join(saved_file_dir, checkpoint_file_name)) as f: + checkpoint = json.load(f) + + self.checkpoint_chunks_info[i] = checkpoint["chunks"] + self.checkpoint_next_index[i] = checkpoint["done_till_index"] + return diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index 6fe3199c..631a8176 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -311,6 +311,7 @@ def optimize( reader: Optional[BaseReader] = None, batch_size: Optional[int] = None, mode: Optional[Literal["append", "overwrite"]] = None, + use_checkpoint: bool = False, ) -> None: """This function converts a dataset into chunks possibly in a distributed way. @@ -336,6 +337,8 @@ def optimize( batch_size: Group the inputs into batches of batch_size length. mode: The mode to use when writing the data. Accepts either ``append`` or ``overwrite`` or None. Defaults to None. + use_checkpoint: Whether to create checkpoints while processing the data, which can be used to resume the + processing from the last checkpoint if the process is interrupted. (`Default: False`) """ if mode is not None and mode not in ["append", "overwrite"]: @@ -377,7 +380,7 @@ def optimize( " HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`." ) - _assert_dir_has_index_file(_output_dir, mode=mode) + _assert_dir_has_index_file(_output_dir, mode=mode, use_checkpoint=use_checkpoint) if not isinstance(inputs, StreamingDataLoader): resolved_dir = _resolve_dir(input_dir or _get_input_dir(inputs)) @@ -412,6 +415,7 @@ def optimize( reorder_files=reorder_files, reader=reader, state_dict=state_dict, + use_checkpoint=use_checkpoint, ) with optimize_dns_context(True): diff --git a/src/litdata/processing/utilities.py b/src/litdata/processing/utilities.py index 9f932882..8312ea39 100644 --- a/src/litdata/processing/utilities.py +++ b/src/litdata/processing/utilities.py @@ -249,3 +249,37 @@ def extract_rank_and_index_from_filename(chunk_filename: str) -> Tuple[int, int] index = int(chunk_filename[1].split(".")[0]) return rank, index + + +def remove_uuid_from_filename(filepath: str) -> str: + """Remove the unique id from the filepath. Expects the filepath to be in the format + `checkpoint-<rank>-<uuid>.json`. + + e.g.: `checkpoint-0-9fe2c4e93f654fdbb24c02b15259716c.json` + -> `checkpoint-0.json` + + """ + + if ".checkpoints" not in filepath: + return filepath + + # the uuid is 32 characters, '.json' is 5 characters, and '-' is 1 character + return filepath[:-38] + ".json" + + +def download_directory_from_S3(bucket_name: str, remote_directory_name: str, local_directory_name: str) -> str: + s3_resource = boto3.resource("s3") + bucket = s3_resource.Bucket(bucket_name) + + saved_file_dir = "."
+ + for obj in bucket.objects.filter(Prefix=remote_directory_name): + local_filename = os.path.join(local_directory_name, obj.key) + + if not os.path.exists(os.path.dirname(local_filename)): + os.makedirs(os.path.dirname(local_filename)) + with open(local_filename, "wb") as f: + s3_resource.meta.client.download_fileobj(bucket_name, obj.key, f) + saved_file_dir = os.path.dirname(local_filename) + + return saved_file_dir diff --git a/src/litdata/streaming/cache.py b/src/litdata/streaming/cache.py index c78440af..d303066e 100644 --- a/src/litdata/streaming/cache.py +++ b/src/litdata/streaming/cache.py @@ -152,3 +152,7 @@ def get_chunk_intervals(self) -> List[Interval]: def _get_chunk_index_from_index(self, index: int) -> Tuple[int, int]: return self._reader._get_chunk_index_from_index(index) + + def save_checkpoint(self, checkpoint_dir: str = ".checkpoints") -> Optional[str]: + """Save the current state of the writer to a checkpoint.""" + return self._writer.save_checkpoint(checkpoint_dir=checkpoint_dir) diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index 95faa21d..633ebc05 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -14,7 +14,9 @@ import datetime import os import re +import shutil import sys +from contextlib import suppress from dataclasses import dataclass from pathlib import Path from time import sleep @@ -246,7 +248,9 @@ def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool ) -def _assert_dir_has_index_file(output_dir: Dir, mode: Optional[Literal["append", "overwrite"]] = None) -> None: +def _assert_dir_has_index_file( + output_dir: Dir, mode: Optional[Literal["append", "overwrite"]] = None, use_checkpoint: bool = False +) -> None: if mode is not None and mode not in ["append", "overwrite"]: raise ValueError(f"The provided `mode` should be either `append` or `overwrite`. 
Found {mode}.") @@ -273,10 +277,17 @@ def _assert_dir_has_index_file(output_dir: Dir, mode: Optional[Literal["append", # delete index.json file and chunks if os.path.exists(os.path.join(output_dir.path, "index.json")): + # only possible if mode = "overwrite" os.remove(os.path.join(output_dir.path, "index.json")) - for file in os.listdir(output_dir.path): - if file.endswith(".bin"): - os.remove(os.path.join(output_dir.path, file)) + + if mode == "overwrite" or (mode is None and not use_checkpoint): + for file in os.listdir(output_dir.path): + if file.endswith(".bin"): + os.remove(os.path.join(output_dir.path, file)) + + # delete checkpoints + with suppress(FileNotFoundError): + shutil.rmtree(os.path.join(output_dir.path, ".checkpoints")) return @@ -316,8 +327,10 @@ def _assert_dir_has_index_file(output_dir: Dir, mode: Optional[Literal["append", # Delete all the files (including the index file in overwrite mode) bucket_name = obj.netloc s3 = boto3.resource("s3") - for obj in s3.Bucket(bucket_name).objects.filter(Prefix=prefix): - s3.Object(bucket_name, obj.key).delete() + + if mode == "overwrite" or (mode is None and not use_checkpoint): + for obj in s3.Bucket(bucket_name).objects.filter(Prefix=prefix): + s3.Object(bucket_name, obj.key).delete() def _get_lightning_cloud_url() -> str: diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py index 9967b95c..dec72d0f 100644 --- a/src/litdata/streaming/writer.py +++ b/src/litdata/streaming/writer.py @@ -13,6 +13,7 @@ import json import os +import uuid import warnings from dataclasses import dataclass from time import sleep @@ -64,7 +65,7 @@ def __init__( """ self._cache_dir = cache_dir - + os.makedirs(self._cache_dir, exist_ok=True) if (isinstance(self._cache_dir, str) and not os.path.exists(self._cache_dir)) or self._cache_dir is None: raise FileNotFoundError(f"The provided cache directory `{self._cache_dir}` doesn't exist.") @@ -103,6 +104,7 @@ def __init__( self._per_sample_num_bytes = 0 self._per_sample_num_items = 0 + self.last_checkpoint_chunk_info: List[Dict[str, Any]] = [] @property def filled(self) -> bool: @@ -458,7 +460,7 @@ def _merge_no_wait(self, node_rank: Optional[int] = None, existing_index: Option elif config != data["config"]: raise Exception( "The config isn't consistent between chunks. This shouldn't have happened." - f"Found {config} {data['config']}." + f"Found {config}; {data['config']}." 
) chunks_info.extend(data["chunks"]) @@ -494,3 +496,25 @@ def _pretty_serialized_items(self) -> Dict[int, Item]: data=b"", ) return out + + def save_checkpoint(self, checkpoint_dir: str = ".checkpoints") -> Optional[str]: + """Save the current state of the writer to a checkpoint.""" + checkpoint_dir = os.path.join(self._cache_dir, checkpoint_dir) + if not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) + + if self._chunks_info == self.last_checkpoint_chunk_info: + # to avoid saving the same checkpoint twice + return None + + unique_id = uuid.uuid4().hex + done_till_index = sum(chnk_info["chunk_size"] for chnk_info in self._chunks_info) + + checkpoint_filepath = os.path.join(checkpoint_dir, f"checkpoint-{self.rank}-{unique_id}.json") + + checkPoint = {"chunks": self._chunks_info, "config": self.get_config(), "done_till_index": done_till_index} + + with open(checkpoint_filepath, "w") as f: + json.dump(checkPoint, f) + + return checkpoint_filepath diff --git a/tests/processing/test_functions.py b/tests/processing/test_functions.py index adabf76e..5bce87b7 100644 --- a/tests/processing/test_functions.py +++ b/tests/processing/test_functions.py @@ -54,7 +54,17 @@ def different_compress(index): return index, index**2, index**3 -@pytest.mark.skipif(sys.platform == "win32" and sys.platform == "darwin", reason="too slow") +def fn(i: int): + if i in [1, 2, 4]: + raise ValueError("An error occurred") + return i, i**2 + + +def another_fn(i: int): + return i, i**2 + + +@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow") def test_optimize_append_overwrite(tmpdir): output_dir = str(tmpdir / "output_dir") @@ -157,6 +167,80 @@ def test_optimize_append_overwrite(tmpdir): assert ds[:] == [(i, i**2, i**3) for i in range(0, 5)] +@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow") +def test_optimize_checkpoint_in_none_and_append_mode(tmpdir): + output_dir = str(tmpdir / "output_dir") + + with pytest.raises(RuntimeError, match="We found the following error"): + optimize( + fn=fn, + inputs=list(range(4)), + output_dir=output_dir, + chunk_size=1, + num_workers=2, + use_checkpoint=True, + ) + + # check that the checkpoints are created + assert os.path.exists(os.path.join(output_dir, ".checkpoints")) + assert os.path.exists(os.path.join(output_dir, ".checkpoints", "config.json")) + + optimize( + fn=another_fn, + inputs=list(range(4)), + output_dir=output_dir, + chunk_size=1, + num_workers=2, + use_checkpoint=True, + ) + + ds = StreamingDataset(output_dir) + + assert len(ds) == 4 + assert ds[:] == [(i, i**2) for i in range(4)] + # checkpoints should be deleted + assert not os.path.exists(os.path.join(output_dir, ".checkpoints")) + + # --------- now test for append mode --------- + + with pytest.raises(RuntimeError, match="We found the following error"): + optimize( + fn=fn, + inputs=list(range(4, 8)), + output_dir=output_dir, + chunk_size=1, + num_workers=2, + use_checkpoint=True, + mode="append", + ) + + # check that the checkpoints are created + assert os.path.exists(os.path.join(output_dir, ".checkpoints")) + assert os.path.exists(os.path.join(output_dir, ".checkpoints", "config.json")) + print("-" * 80) + # print all the files in the checkpoints folder + for f in os.listdir(os.path.join(output_dir, ".checkpoints")): + print(f) + print("-" * 80) + + optimize( + fn=another_fn, + inputs=list(range(4, 8)), + output_dir=output_dir, + chunk_size=1, + num_workers=2, + use_checkpoint=True, + mode="append", + ) + + ds = 
StreamingDataset(output_dir) + + assert len(ds) == 8 + assert ds[:] == [(i, i**2) for i in range(8)] + # checkpoints should be deleted + assert not os.path.exists(os.path.join(output_dir, ".checkpoints")) + + def test_merge_datasets(tmpdir): folder_1 = os.path.join(tmpdir, "folder_1") folder_2 = os.path.join(tmpdir, "folder_2") diff --git a/tests/processing/test_utilities.py b/tests/processing/test_utilities.py index e3196455..411779eb 100644 --- a/tests/processing/test_utilities.py +++ b/tests/processing/test_utilities.py @@ -6,6 +6,7 @@ extract_rank_and_index_from_filename, optimize_dns_context, read_index_file_content, + remove_uuid_from_filename, ) from litdata.streaming.resolver import _resolve_dir @@ -84,3 +85,28 @@ def test_read_index_file_content(tmpdir): json.dump(dummy_dict, f) assert read_index_file_content(_resolve_dir(str(output_dir))) == dummy_dict + + +def test_remove_uuid_from_filename(): + filepaths = [ + "checkpoint-0-9fe2c4e93f654fdbb24c02b15259716c.json", + "checkpoint-1-9fe2c4e93f654fdbb24c02b15259716c.json", + "checkpoint-2-9fe2c4e93f654fdbb24c02b15259716c.json", + "checkpoint-101-9fe2c4e93f654fdbb24c02b15259716c.json", + "checkpoint-12-9fe2c4e93f654fdbb24c02b15259716c.json", + "checkpoint-267-9fe2c4e93f654fdbb24c02b15259716c.json", + ] + + expected = [ + "checkpoint-0.json", + "checkpoint-1.json", + "checkpoint-2.json", + "checkpoint-101.json", + "checkpoint-12.json", + "checkpoint-267.json", + ] + + for idx, filepath in enumerate(filepaths): + filepath = ".checkpoints/" + filepath + result = remove_uuid_from_filename(filepath) + assert result == ".checkpoints/" + expected[idx] diff --git a/tests/streaming/test_cache.py b/tests/streaming/test_cache.py index 9bb97e65..ae338c88 100644 --- a/tests/streaming/test_cache.py +++ b/tests/streaming/test_cache.py @@ -320,3 +320,23 @@ def test_cache_for_text_tokens(tmpdir): with pytest.raises(ValueError, match="TokensLoader"): len(Cache(str(tmpdir), chunk_size=block_size * 11)) + + +def test_cache_checkpoint(tmpdir): + cache_dir = os.path.join(tmpdir, "cache_checkpoint") + os.makedirs(cache_dir) + + cache = Cache(cache_dir, chunk_bytes=90) + + # you encode data + for i in range(100): + cache[i] = i + + # I am done, write the index ... 
+ cache.done() + cache.merge() + cache.save_checkpoint() + + for file in os.listdir(os.path.join(cache_dir, ".checkpoints")): + assert "checkpoint-0" in file + assert file.endswith(".json") diff --git a/tests/streaming/test_writer.py b/tests/streaming/test_writer.py index 7d377ae9..04ae209f 100644 --- a/tests/streaming/test_writer.py +++ b/tests/streaming/test_writer.py @@ -37,9 +37,6 @@ def seed_everything(random_seed): def test_binary_writer_with_ints_and_chunk_bytes(tmpdir): - with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."): - BinaryWriter("dontexists", {}) - match = ( "The provided compression something_else isn't available" if _ZSTD_AVAILABLE @@ -81,9 +78,6 @@ def test_binary_writer_with_ints_and_chunk_size(tmpdir): seed_everything(42) - with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."): - BinaryWriter("dontexists", {}) - match = ( "The provided compression something_else isn't available" if _ZSTD_AVAILABLE @@ -252,3 +246,23 @@ def test_writer_unordered_indexes(tmpdir): assert data["chunks"][0]["chunk_size"] == 5 assert data["chunks"][1]["chunk_size"] == 5 assert data["chunks"][2]["chunk_size"] == 2 + + +def test_writer_save_checkpoint(tmpdir): + cache_dir = os.path.join(tmpdir, "chunks") + os.makedirs(cache_dir, exist_ok=True) + + binary_writer = BinaryWriter(cache_dir, chunk_size=5) + + arr = [2, 3, 1, 4, 6, 5, 7, 8, 11, 9, 10, 12] + + for i in arr: + binary_writer[i] = i - 1 + + binary_writer.done() + binary_writer.merge() + binary_writer.save_checkpoint() + + for file in os.listdir(os.path.join(cache_dir, ".checkpoints")): + assert "checkpoint-0" in file + assert file.endswith(".json")
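For reference, the sketch below is illustrative only and is not part of the patch. It reproduces, with hypothetical helper names (`save_worker_checkpoint`, `strip_uuid`, `resume_worker_items`), the checkpoint round trip introduced above: `BinaryWriter.save_checkpoint` writes a per-worker `checkpoint-<rank>-<uuid>.json` holding the chunk list and a `done_till_index` counter, `remove_uuid_from_filename` strips the unique id when the file is uploaded to `.checkpoints/`, and `DataProcessor._load_checkpoint_config` slices each worker's item list so a rerun skips the items that were already written.

# Illustrative sketch only -- mimics the checkpoint mechanics added in this PR.
# The helper names are hypothetical stand-ins, not litdata APIs.
import json
import os
import tempfile
import uuid
from typing import Any, Dict, List


def save_worker_checkpoint(checkpoint_dir: str, rank: int, chunks_info: List[Dict[str, Any]]) -> str:
    # Mimics BinaryWriter.save_checkpoint: one JSON per worker, tagged with a unique id
    # so every save produces a new local file; the id is stripped at upload time.
    os.makedirs(checkpoint_dir, exist_ok=True)
    done_till_index = sum(chunk["chunk_size"] for chunk in chunks_info)
    filepath = os.path.join(checkpoint_dir, f"checkpoint-{rank}-{uuid.uuid4().hex}.json")
    with open(filepath, "w") as f:
        json.dump({"chunks": chunks_info, "done_till_index": done_till_index}, f)
    return filepath


def strip_uuid(filepath: str) -> str:
    # Mimics remove_uuid_from_filename: '-' (1) + uuid (32) + '.json' (5) = 38 chars.
    return filepath[:-38] + ".json"


def resume_worker_items(checkpoint_dir: str, workers_user_items: List[List[Any]]) -> List[List[Any]]:
    # Mimics the resume path in _load_checkpoint_config: each worker skips the items
    # it already processed, based on the done_till_index recorded in its checkpoint.
    resumed = []
    for rank, items in enumerate(workers_user_items):
        checkpoint_file = os.path.join(checkpoint_dir, f"checkpoint-{rank}.json")
        if not os.path.exists(checkpoint_file):
            resumed.append(items)  # no checkpoint for this worker -> start from scratch
            continue
        with open(checkpoint_file) as f:
            checkpoint = json.load(f)
        resumed.append(items[checkpoint["done_till_index"]:])
    return resumed


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        ckpt_dir = os.path.join(tmp, ".checkpoints")
        # Worker 0 completed two one-item chunks before the hypothetical interruption.
        saved = save_worker_checkpoint(ckpt_dir, rank=0, chunks_info=[{"chunk_size": 1}, {"chunk_size": 1}])
        # On upload the unique id is stripped, so the remote name per worker stays stable.
        os.replace(saved, strip_uuid(saved))
        items = [[0, 1, 2, 3], [4, 5, 6, 7]]  # two workers, four items each
        print(resume_worker_items(ckpt_dir, items))  # -> [[2, 3], [4, 5, 6, 7]]

From user code the feature is exercised exactly as in `test_optimize_checkpoint_in_none_and_append_mode`: call `optimize(..., use_checkpoint=True)` and simply rerun the same call after an interruption; the `.checkpoints/` folder is validated against the current configuration (worker count and item split in `config.json`), used to skip completed items, and deleted once the run finishes.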