From fe6e0262c7aab93da1cfeae8d59bde037c9d7c52 Mon Sep 17 00:00:00 2001
From: Deependu Jha
Date: Thu, 27 Jun 2024 13:12:19 +0530
Subject: [PATCH] Feat: Append data to pre-optimize dataset (#184)

Co-authored-by: tchaton
---
 README.md                                 |  47 ++++++++++
 src/litdata/processing/data_processor.py  |  11 ++-
 src/litdata/processing/functions.py       |  32 ++++++-
 src/litdata/processing/utilities.py       |  76 ++++++++++++++-
 src/litdata/streaming/cache.py            |   7 +-
 src/litdata/streaming/resolver.py         |  35 ++++++-
 src/litdata/streaming/writer.py           |  19 +++-
 tests/processing/test_functions.py        | 113 ++++++++++++++++++++++-
 tests/processing/test_utilities.py        |  52 ++++++++++-
 9 files changed, 376 insertions(+), 16 deletions(-)

diff --git a/README.md b/README.md
index 97671b5b..236c9cfb 100644
--- a/README.md
+++ b/README.md
@@ -122,6 +122,7 @@ dataloader = StreamingDataLoader(dataset)
 - [Multi-GPU / Multi-Node Support](#multi-gpu--multi-node-support)
 - [Subsample and split your datasets](#subsample-and-split-your-datasets)
+- [Append or Overwrite optimized datasets](#append-or-overwrite-optimized-datasets)
 - [Access any item](#access-any-item)
 - [Use any data transforms](#use-any-data-transforms)
 - [The Map Operator](#the-map-operator)
@@ -177,6 +178,52 @@ print(len(dataset)) # display the length of your data
 # out: 1000
 ```

+Or simply subsample them
+
+```python
+from litdata import StreamingDataset, train_test_split
+
+dataset = StreamingDataset("s3://my-bucket/my-data", subsample=0.01) # data are stored in the cloud
+
+print(len(dataset)) # display the length of your data
+# out: 1000
+```
+
+## Append or overwrite optimized datasets
+
+LitData optimized datasets are assumed to be immutable. However, you can choose to modify them by setting `mode` to either `append` or `overwrite`.
+
+```python
+from litdata import optimize, StreamingDataset
+
+def compress(index):
+    return index, index**2
+
+if __name__ == "__main__":
+    # Add some data
+    optimize(
+        fn=compress,
+        inputs=list(range(100)),
+        output_dir="./my_optimized_dataset",
+        chunk_bytes="64MB",
+    )
+
+    # Later on, you add more data
+    optimize(
+        fn=compress,
+        inputs=list(range(100, 200)),
+        output_dir="./my_optimized_dataset",
+        chunk_bytes="64MB",
+        mode="append",
+    )
+
+    ds = StreamingDataset("./my_optimized_dataset")
+    assert len(ds) == 200
+    assert ds[:] == [(i, i**2) for i in range(200)]
+```
+
+The `overwrite` mode deletes the existing data and starts fresh.
+
 ## Access any item

 Access the data you need, whenever you need it, regardless of where it is stored.
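For the `overwrite` path mentioned at the end of the README section above, a minimal sketch (reusing the same toy `compress` function; illustrative only, not part of the patch) could look like:

```python
from litdata import optimize, StreamingDataset

def compress(index):
    return index, index**2

if __name__ == "__main__":
    # Rebuild the dataset from scratch: the existing index.json and chunk files
    # are removed before the new chunks are written.
    optimize(
        fn=compress,
        inputs=list(range(50)),
        output_dir="./my_optimized_dataset",
        chunk_bytes="64MB",
        mode="overwrite",
    )

    ds = StreamingDataset("./my_optimized_dataset")
    assert len(ds) == 50  # only the latest inputs remain
```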
diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py
index 163f572d..8906fe2a 100644
--- a/src/litdata/processing/data_processor.py
+++ b/src/litdata/processing/data_processor.py
@@ -389,6 +389,7 @@ def __init__(
         num_uploaders: int,
         remove: bool,
         reader: Optional[BaseReader] = None,
+        writer_starting_chunk_index: int = 0,
     ) -> None:
         """The BaseWorker is responsible to process the user data."""
         self.worker_index = worker_index
@@ -417,6 +418,7 @@ def __init__(
         self._counter = 0
         self._last_time = time()
         self._index_counter = 0
+        self.writer_starting_chunk_index = writer_starting_chunk_index

     def run(self) -> None:
         try:
@@ -510,6 +512,7 @@ def _create_cache(self) -> None:
             chunk_bytes=self.data_recipe.chunk_bytes,
             chunk_size=self.data_recipe.chunk_size,
             compression=self.data_recipe.compression,
+            writer_chunk_index=self.writer_starting_chunk_index,
         )

         self.cache._reader._rank = _get_node_rank() * self.num_workers + self.worker_index
@@ -738,7 +741,8 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul
         merge_cache = Cache(cache_dir, chunk_bytes=1)
         node_rank = _get_node_rank()
-        merge_cache._merge_no_wait(node_rank if num_nodes > 1 else None)
+        merge_cache._merge_no_wait(node_rank if num_nodes > 1 else None, getattr(self, "existing_index", None))
+
         self._upload_index(output_dir, cache_dir, num_nodes, node_rank)

         if num_nodes == node_rank + 1:
@@ -844,6 +848,7 @@ def __init__(
         reorder_files: bool = True,
         weights: Optional[List[int]] = None,
         reader: Optional[BaseReader] = None,
+        state_dict: Optional[Dict[int, int]] = None,
     ):
         """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to
         make training faster.
@@ -862,6 +867,7 @@ def __init__(
             weights: Provide a list of weights associated to the inputs.
                 This is used to evenly split the work among the workers.
             reader: Map the inputs to worker inputs and provides a read method to read a slice of the data.
+            state_dict: The writer state dict. This is used to decide how to append data to an existing dataset.

         """
         self.input_dir = _resolve_dir(input_dir)
@@ -881,6 +887,8 @@ def __init__(
         self.weights = weights
         self.reader = reader

+        self.state_dict = state_dict or {rank: 0 for rank in range(self.num_workers)}
+
         if self.reader is not None and self.weights is not None:
             raise ValueError("Either the reader or the weights needs to be defined.")

@@ -1061,6 +1069,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L
                 self.num_uploaders,
                 self.delete_cached_files,
                 self.reader,
+                self.state_dict[worker_idx],
             )
             worker.start()
             workers.append(worker)
diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py
index 6a8562d6..c38db2b3 100644
--- a/src/litdata/processing/functions.py
+++ b/src/litdata/processing/functions.py
@@ -18,7 +18,7 @@
 from functools import partial
 from pathlib import Path
 from types import FunctionType
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
 from urllib import parse

 import torch
@@ -26,7 +26,11 @@
 from litdata.constants import _IS_IN_STUDIO
 from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from litdata.processing.readers import BaseReader
-from litdata.processing.utilities import optimize_dns_context
+from litdata.processing.utilities import (
+    extract_rank_and_index_from_filename,
+    optimize_dns_context,
+    read_index_file_content,
+)
 from litdata.streaming.dataloader import StreamingDataLoader
 from litdata.streaming.resolver import (
     Dir,
@@ -136,11 +140,13 @@ def __init__(
         chunk_size: Optional[int],
         chunk_bytes: Optional[Union[int, str]],
         compression: Optional[str],
+        existing_index: Optional[Dict[str, Any]] = None,
     ):
         super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
         self._fn = fn
         self._inputs = inputs
         self.is_generator = False
+        self.existing_index = existing_index

         self.check_fn()

@@ -292,6 +298,7 @@ def optimize(
     reorder_files: bool = True,
     reader: Optional[BaseReader] = None,
     batch_size: Optional[int] = None,
+    mode: Optional[Literal["append", "overwrite"]] = None,
 ) -> None:
     """This function converts a dataset into chunks possibly in a distributed way.

@@ -315,8 +322,13 @@ def optimize(
         reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
             Set this to ``False`` if the order in which samples are processed should be preserved.
         batch_size: Group the inputs into batches of batch_size length.
+        mode: The mode to use when writing the data. Accepts either ``append``, ``overwrite``, or ``None``.
+            Defaults to ``None``.

     """
+    if mode is not None and mode not in ["append", "overwrite"]:
+        raise ValueError(f"The provided `mode` should be either `append` or `overwrite`. Found {mode}.")
+
     if isinstance(inputs, StreamingDataLoader) and batch_size is not None:
         raise ValueError("When providing a streaming dataloader, pass the batch_size to the dataloader directly.")

@@ -353,7 +365,7 @@ def optimize(
             " HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`."
         )

-    _assert_dir_has_index_file(_output_dir)
+    _assert_dir_has_index_file(_output_dir, mode=mode)

     if not isinstance(inputs, StreamingDataLoader):
         resolved_dir = _resolve_dir(input_dir or _get_input_dir(inputs))
@@ -366,6 +378,18 @@ def optimize(
     if num_workers == 0:
         num_workers = 1

+    num_workers = num_workers or _get_default_num_workers()
+    state_dict = {rank: 0 for rank in range(num_workers)}
+
+    existing_index_file_content = read_index_file_content(_output_dir) if mode == "append" else None
+
+    if existing_index_file_content is not None:
+        for chunk in existing_index_file_content["chunks"]:
+            rank, index = extract_rank_and_index_from_filename(chunk["filename"])
+
+            if rank < num_workers and state_dict[rank] <= index:
+                state_dict[rank] = index + 1  # +1 because we want to start from the next index
+
     data_processor = DataProcessor(
         input_dir=resolved_dir,
         output_dir=_output_dir,
@@ -375,6 +399,7 @@ def optimize(
         num_uploaders=num_uploaders,
         reorder_files=reorder_files,
         reader=reader,
+        state_dict=state_dict,
     )

     with optimize_dns_context(True):
@@ -385,6 +410,7 @@ def optimize(
                 chunk_size=chunk_size,
                 chunk_bytes=chunk_bytes,
                 compression=compression,
+                existing_index=existing_index_file_content,
             )
         )
     return None
diff --git a/src/litdata/processing/utilities.py b/src/litdata/processing/utilities.py
index 84c18097..9f932882 100644
--- a/src/litdata/processing/utilities.py
+++ b/src/litdata/processing/utilities.py
@@ -12,13 +12,17 @@
 # limitations under the License.

 import io
+import json
 import os
+import tempfile
 import urllib
 from contextlib import contextmanager
 from subprocess import DEVNULL, Popen
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from urllib import parse

-from litdata.constants import _IS_IN_STUDIO, _LIGHTNING_CLOUD_AVAILABLE
+from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _LIGHTNING_CLOUD_AVAILABLE
+from litdata.streaming.cache import Dir

 if _LIGHTNING_CLOUD_AVAILABLE:
     from lightning_cloud.openapi import (
@@ -27,6 +31,14 @@
     from lightning_cloud.openapi.rest import ApiException
     from lightning_cloud.rest_client import LightningClient

+try:
+    import boto3
+    import botocore
+
+    _BOTO3_AVAILABLE = True
+except Exception:
+    _BOTO3_AVAILABLE = False
+

 def _create_dataset(
     input_dir: Optional[str],
@@ -177,3 +189,63 @@ def _get_work_dir() -> str:
     assert project_id is not None
     assert work_id is not None
     return f"s3://{bucket_name}/projects/{project_id}/lightningapps/{app_id}/artifacts/{work_id}/content/"
+
+
+def read_index_file_content(output_dir: Dir) -> Optional[Dict[str, Any]]:
+    """Read the index file content."""
+    if not isinstance(output_dir, Dir):
+        raise ValueError("The provided output_dir should be a Dir object.")
+
+    if output_dir.url is None:
+        if output_dir.path is None:
+            return None
+        index_file_path = os.path.join(output_dir.path, _INDEX_FILENAME)
+        if not os.path.exists(index_file_path):
+            return None
+        with open(index_file_path) as f:
+            return json.load(f)
+
+    else:
+        # download the index file from s3, and read it
+        obj = parse.urlparse(output_dir.url)
+
+        if obj.scheme != "s3":
+            raise ValueError(f"The provided folder should start with s3://. Found {output_dir.path}.")
+
+        # TODO: Add support for all cloud providers
+        s3 = boto3.client("s3")
+
+        prefix = obj.path.lstrip("/").rstrip("/") + "/"
+
+        # Check the index file exists
+        try:
+            # Create a temporary file
+            with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file:
+                temp_file_name = temp_file.name
+                s3.download_file(obj.netloc, os.path.join(prefix, _INDEX_FILENAME), temp_file_name)
+            # Read data from the temporary file
+            with open(temp_file_name) as temp_file:
+                data = json.load(temp_file)
+            # Delete the temporary file
+            os.remove(temp_file_name)
+            return data
+        except botocore.exceptions.ClientError:
+            return None
+
+
+def extract_rank_and_index_from_filename(chunk_filename: str) -> Tuple[int, int]:
+    """Extract the rank and index from the filename.
+
+    It is assumed that the filename is in the format `chunk-<rank>-<index>.bin` or
+    `chunk-<rank>-<index>.<compressionAlgorithm>.bin`.
+
+    """
+    # remove chunk and bin
+    chunk_filename = chunk_filename[6:-4].split("-")  # (0, 0) or (0, 0.compressionAlgorithm)
+    assert len(chunk_filename) == 2
+
+    # get the rank and index
+    rank = int(chunk_filename[0])
+    index = int(chunk_filename[1].split(".")[0])
+
+    return rank, index
diff --git a/src/litdata/streaming/cache.py b/src/litdata/streaming/cache.py
index 896bba03..be5863b4 100644
--- a/src/litdata/streaming/cache.py
+++ b/src/litdata/streaming/cache.py
@@ -42,6 +42,7 @@ def __init__(
         item_loader: Optional[BaseItemLoader] = None,
         max_cache_size: Union[int, str] = "100GB",
         serializers: Optional[Dict[str, Serializer]] = None,
+        writer_chunk_index: Optional[int] = None,
     ):
         """The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements
         together in order to accelerate fetching.
@@ -56,6 +57,7 @@ def __init__(
             item_loader: The object responsible to generate the chunk intervals and load an item froma chunk.
             max_cache_size: The maximum cache size used by the reader when fetching the chunks.
             serializers: Provide your own serializers.
+            writer_chunk_index: The index of the chunk to start from when writing.

         """
         super().__init__()
@@ -68,6 +70,7 @@ def __init__(
             chunk_bytes=chunk_bytes,
             compression=compression,
             serializers=serializers,
+            chunk_index=writer_chunk_index or 0,
         )
         self._reader = BinaryReader(
             self._cache_dir,
@@ -137,9 +140,9 @@ def merge(self, num_workers: int = 1, node_rank: Optional[int] = None) -> None:
         """Inform the writer the chunking phase is finished."""
         self._writer.merge(num_workers, node_rank=node_rank)

-    def _merge_no_wait(self, node_rank: Optional[int] = None) -> None:
+    def _merge_no_wait(self, node_rank: Optional[int] = None, existing_index: Optional[Dict[str, Any]] = None) -> None:
         """Inform the writer the chunking phase is finished."""
-        self._writer._merge_no_wait(node_rank=node_rank)
+        self._writer._merge_no_wait(node_rank=node_rank, existing_index=existing_index)

     def __len__(self) -> int:
         return self._reader.get_length()
diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py
index 0b4ad832..95faa21d 100644
--- a/src/litdata/streaming/resolver.py
+++ b/src/litdata/streaming/resolver.py
@@ -18,7 +18,7 @@
 from dataclasses import dataclass
 from pathlib import Path
 from time import sleep
-from typing import Any, Optional, Union
+from typing import Any, Literal, Optional, Union
 from urllib import parse

 from litdata.constants import _LIGHTNING_CLOUD_AVAILABLE
@@ -246,11 +246,38 @@ def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool
         )


-def _assert_dir_has_index_file(output_dir: Dir) -> None:
+def _assert_dir_has_index_file(output_dir: Dir, mode: Optional[Literal["append", "overwrite"]] = None) -> None:
+    if mode is not None and mode not in ["append", "overwrite"]:
+        raise ValueError(f"The provided `mode` should be either `append` or `overwrite`. Found {mode}.")
+
+    if mode == "append":
+        # in append mode, we neither need to delete the index file nor the chunks
+        return
+
     if not isinstance(output_dir, Dir):
         raise ValueError("The provided output_dir isn't a Dir Object.")

     if output_dir.url is None:
+        # this is a local directory
+        assert output_dir.path
+
+        if os.path.exists(output_dir.path):
+            # we need to delete the index file
+            index_file = os.path.join(output_dir.path, "index.json")
+            if os.path.exists(index_file) and mode is None:
+                raise RuntimeError(
+                    f"The provided output_dir `{output_dir.path}` already contains an optimized immutable datasets."
+                    " HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?"
+                    " HINT: If you want to append/overwrite to the existing dataset, use `mode='append | overwrite'`."
+                )
+
+            # delete index.json file and chunks
+            if os.path.exists(os.path.join(output_dir.path, "index.json")):
+                os.remove(os.path.join(output_dir.path, "index.json"))
+            for file in os.listdir(output_dir.path):
+                if file.endswith(".bin"):
+                    os.remove(os.path.join(output_dir.path, file))
+
         return

     obj = parse.urlparse(output_dir.url)
@@ -279,12 +306,14 @@ def _assert_dir_has_index_file(output_dir: Dir) -> None:
     except botocore.exceptions.ClientError:
         has_index_file = False

-    if has_index_file:
+    if has_index_file and mode is None:
         raise RuntimeError(
             f"The provided output_dir `{output_dir.path}` already contains an optimized immutable datasets."
             " HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?"
+            " HINT: If you want to append/overwrite to the existing dataset, use `mode='append | overwrite'`."
         )

+    # Delete all the files (including the index file in overwrite mode)
     bucket_name = obj.netloc
     s3 = boto3.resource("s3")
     for obj in s3.Bucket(bucket_name).objects.filter(Prefix=prefix):
diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py
index 68a12469..9967b95c 100644
--- a/src/litdata/streaming/writer.py
+++ b/src/litdata/streaming/writer.py
@@ -50,6 +50,7 @@ def __init__(
         compression: Optional[str] = None,
         follow_tensor_dimension: bool = True,
         serializers: Optional[Dict[str, Serializer]] = None,
+        chunk_index: Optional[int] = None,
     ):
         """The BinaryWriter enables to chunk dataset into an efficient streaming format for cloud training.

@@ -59,6 +60,7 @@ def __init__(
             chunk_size: The maximum number of items within a chunk.
             compression: The compression algorithm to use.
             serializers: Provide your own serializers.
+            chunk_index: The index of the chunk to start from.

         """
         self._cache_dir = cache_dir
@@ -89,7 +91,7 @@ def __init__(
             self._compressor: Compressor = _COMPRESSORS[self._compression]

         self._serialized_items: Dict[int, Item] = {}
-        self._chunk_index = 0
+        self._chunk_index = chunk_index or 0
         self._min_index: Optional[int] = None
         self._max_index: Optional[int] = None
         self._chunks_info: List[Dict[str, Any]] = []
@@ -426,14 +428,25 @@ def merge(self, num_workers: int = 1, node_rank: Optional[int] = None) -> None:

         self._merge_no_wait(node_rank=node_rank)

-    def _merge_no_wait(self, node_rank: Optional[int] = None) -> None:
+    def _merge_no_wait(self, node_rank: Optional[int] = None, existing_index: Optional[Dict[str, Any]] = None) -> None:
         """Once all the workers have written their own index, the merge function is responsible to read and merge them
-        into a single index."""
+        into a single index.
+
+        Arguments:
+            node_rank: The node rank of the index file
+            existing_index: Existing index to be added to the newly created one.
+
+        """
         files = os.listdir(self._cache_dir)
         index_files = [f for f in files if f.endswith(_INDEX_FILENAME)]

         chunks_info = []
         config = None
+
+        if existing_index is not None:
+            chunks_info.extend(existing_index["chunks"])
+            config = existing_index["config"]
+
         for index_filename in sorted(index_files):
             chunk_path = os.path.join(self._cache_dir, index_filename)
             with open(chunk_path) as f:
diff --git a/tests/processing/test_functions.py b/tests/processing/test_functions.py
index 9f174cca..c7bb11b1 100644
--- a/tests/processing/test_functions.py
+++ b/tests/processing/test_functions.py
@@ -3,7 +3,7 @@
 from unittest import mock

 import pytest
-from litdata import walk
+from litdata import StreamingDataset, optimize, walk
 from litdata.processing.functions import _get_input_dir, _resolve_dir

@@ -43,3 +43,114 @@ def test_get_input_dir_with_s3_path():
     input_dir = _resolve_dir(input_dir)
     assert not input_dir.path
     assert input_dir.url == "s3://my_bucket/my_folder"
+
+
+def compress(index):
+    return index, index**2
+
+
+def different_compress(index):
+    return index, index**2, index**3
+
+
+@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow")
+def test_optimize_append_overwrite(tmpdir):
+    output_dir = str(tmpdir / "output_dir")
+
+    optimize(
+        fn=compress,
+        inputs=list(range(5)),
+        num_workers=1,
+        output_dir=output_dir,
+        chunk_bytes="64MB",
+    )
+
+    ds = StreamingDataset(output_dir)
+
+    assert len(ds) == 5
+    assert ds[:] == [(i, i**2) for i in range(5)]
+
+    with pytest.raises(RuntimeError, match="HINT: If you want to append/overwrite to the existing dataset"):
+        optimize(
+            fn=compress,
+            inputs=list(range(5, 10)),
+            num_workers=1,
+            output_dir=output_dir,
+            chunk_bytes="64MB",
+        )
+
+    with pytest.raises(ValueError, match="The provided `mode` should be either `append` or `overwrite`"):
+        optimize(
+            fn=compress,
+            inputs=list(range(5, 10)),
+            num_workers=1,
+            output_dir=output_dir,
+            chunk_bytes="64MB",
+            mode="some-random-mode",
+        )
+
+    optimize(
+        fn=compress,
+        inputs=list(range(5, 10)),
+        num_workers=2,
+        output_dir=output_dir,
+        chunk_bytes="64MB",
+        mode="overwrite",
+    )
+
+    ds = StreamingDataset(output_dir)
+
+    assert len(ds) == 5
+    assert ds[:] == [(i, i**2) for i in range(5, 10)]
+
+    optimize(
+        fn=compress,
+        inputs=list(range(10, 15)),
+        num_workers=os.cpu_count(),
+        output_dir=output_dir,
+        chunk_bytes="64MB",
+        mode="append",
+    )
+
+    ds = StreamingDataset(output_dir)
+
+    assert len(ds) == 10
+    assert ds[:] == [(i, i**2) for i in range(5, 15)]
+
+    optimize(
+        fn=compress,
+        inputs=list(range(15, 20)),
+        num_workers=2,
+        output_dir=output_dir,
+        chunk_bytes="64MB",
+        mode="append",
+    )
+
+    ds = StreamingDataset(output_dir)
+
+    assert len(ds) == 15
+    assert ds[:] == [(i, i**2) for i in range(5, 20)]
+
+    with pytest.raises(Exception, match="The config isn't consistent between chunks"):
+        optimize(
+            fn=different_compress,
+            inputs=list(range(100, 200)),
+            num_workers=1,
+            output_dir=output_dir,
+            chunk_bytes="64MB",
+            mode="append",
+        )
+
+    optimize(
+        fn=different_compress,
+        inputs=list(range(0, 5)),
+        num_workers=1,
+        output_dir=output_dir,
+        chunk_bytes="64MB",
+        mode="overwrite",
+    )
+
+    ds = StreamingDataset(output_dir)
+
+    assert len(ds) == 5
+    assert ds[:] == [(i, i**2, i**3) for i in range(0, 5)]
diff --git a/tests/processing/test_utilities.py b/tests/processing/test_utilities.py
index 436e6063..e3196455 100644
--- a/tests/processing/test_utilities.py
+++ b/tests/processing/test_utilities.py
@@ -1,7 +1,13 @@
+import json
 from unittest.mock import MagicMock

 from litdata.processing import utilities as utilities_module
-from litdata.processing.utilities import optimize_dns_context
+from litdata.processing.utilities import (
+    extract_rank_and_index_from_filename,
+    optimize_dns_context,
+    read_index_file_content,
+)
+from litdata.streaming.resolver import _resolve_dir


 def test_optimize_dns_context(monkeypatch):
@@ -34,3 +40,47 @@ def readlines(self):
         " -c 'from litdata.processing.utilities import _optimize_dns; _optimize_dns(True)'"
     )
     assert cmd == expected_cmd
+
+
+def test_extract_rank_and_index_from_filename():
+    file_names = [
+        "chunk-0-0.bin",
+        "chunk-0-0.compressionAlgorithm.bin",
+        "chunk-1-4.bin",
+        "chunk-1-9.compressionAlgorithm.bin",
+        "chunk-22-10.bin",
+        "chunk-2-3.compressionAlgorithm.bin",
+        "chunk-31-3.bin",
+        "chunk-3-110.compressionAlgorithm.bin",
+    ]
+
+    rank_and_index = [
+        (0, 0),
+        (0, 0),
+        (1, 4),
+        (1, 9),
+        (22, 10),
+        (2, 3),
+        (31, 3),
+        (3, 110),
+    ]
+
+    for idx, file_name in enumerate(file_names):
+        rank, index = extract_rank_and_index_from_filename(file_name)
+        assert rank == rank_and_index[idx][0]
+        assert index == rank_and_index[idx][1]
+
+
+def test_read_index_file_content(tmpdir):
+    output_dir = tmpdir / "output_dir"
+
+    assert read_index_file_content(_resolve_dir(str(output_dir))) is None
+
+    output_dir.mkdir()
+    assert read_index_file_content(_resolve_dir(str(output_dir))) is None
+
+    with open(output_dir / "index.json", "w") as f:
+        dummy_dict = {"chunks": ["abc.bin", "def.bin"], "config": {"data_format": "a", "data_spec": "b"}}
+        json.dump(dummy_dict, f)
+
+    assert read_index_file_content(_resolve_dir(str(output_dir))) == dummy_dict
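The append path derives each worker's starting chunk index from the chunks already listed in `index.json`, so newly written chunk files never collide with existing ones. A minimal standalone sketch of that bookkeeping (hypothetical chunk names; the parsing mirrors `extract_rank_and_index_from_filename` and the `state_dict` loop added to `optimize`):

```python
from typing import Dict, List, Tuple


def parse_rank_and_index(chunk_filename: str) -> Tuple[int, int]:
    # "chunk-<rank>-<index>.bin" (optionally with a compression suffix) -> (rank, index)
    body = chunk_filename[len("chunk-"):-len(".bin")]
    rank_str, index_str = body.split("-")
    return int(rank_str), int(index_str.split(".")[0])


def starting_chunk_indexes(existing_chunks: List[str], num_workers: int) -> Dict[int, int]:
    # One entry per worker rank; each worker resumes writing after its highest existing chunk.
    state_dict = {rank: 0 for rank in range(num_workers)}
    for filename in existing_chunks:
        rank, index = parse_rank_and_index(filename)
        if rank < num_workers and state_dict[rank] <= index:
            state_dict[rank] = index + 1
    return state_dict


assert starting_chunk_indexes(["chunk-0-0.bin", "chunk-0-1.bin", "chunk-1-0.bin"], 2) == {0: 2, 1: 1}
```

With `{0: 2, 1: 1}`, worker 0 would write `chunk-0-2.bin` next and worker 1 `chunk-1-1.bin`; the `writer_chunk_index` argument threaded through `Cache` into `BinaryWriter._chunk_index` is what makes the writers start at those offsets.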