Commit
Feat: Append data to pre-optimize dataset (#184)
Co-authored-by: tchaton <[email protected]>
deependujha and tchaton authored Jun 27, 2024
1 parent a8f33df commit fe6e026
Showing 9 changed files with 376 additions and 16 deletions.
47 changes: 47 additions & 0 deletions README.md
@@ -122,6 +122,7 @@ dataloader = StreamingDataLoader(dataset)

- [Multi-GPU / Multi-Node Support](#multi-gpu--multi-node-support)
- [Subsample and split your datasets](#subsample-and-split-your-datasets)
- [Append or Overwrite optimized datasets](#append-or-overwrite-optimized-datasets)
- [Access any item](#access-any-item)
- [Use any data transforms](#use-any-data-transforms)
- [The Map Operator](#the-map-operator)
@@ -177,6 +178,52 @@ print(len(dataset)) # display the length of your data
# out: 1000
```

Or simply subsample them

```python
from litdata import StreamingDataset, train_test_split

dataset = StreamingDataset("s3://my-bucket/my-data", subsample=0.01) # data are stored in the cloud

print(len(dataset)) # display the length of your data
# out: 1000
```

## Append or overwrite optimized datasets

LitData optimized datasets are assumed to be immutable. However, you can choose to modify them by setting `mode` to either `append` or `overwrite`.

```python
from litdata import optimize, StreamingDataset

def compress(index):
return index, index**2

if __name__ == "__main__":
# Add some data
optimize(
fn=compress,
inputs=list(range(100)),
output_dir="./my_optimized_dataset",
chunk_bytes="64MB",
)

# Later on, you add more data
optimize(
fn=compress,
inputs=list(range(100, 200)),
output_dir="./my_optimized_dataset",
chunk_bytes="64MB",
mode="append",
)

ds = StreamingDataset("./my_optimized_dataset")
assert len(ds) == 200
assert ds[:] == [(i, i**2) for i in range(200)]
```

The `overwrite` mode deletes the existing data and starts fresh.
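
For example, building on the snippet above, an `overwrite` run could look like the following sketch (it assumes the dataset from the previous example already exists on disk):

```python
from litdata import optimize, StreamingDataset

def compress(index):
    return index, index**2

if __name__ == "__main__":
    # Rebuild the dataset from scratch, discarding the previously written chunks
    optimize(
        fn=compress,
        inputs=list(range(50)),
        output_dir="./my_optimized_dataset",
        chunk_bytes="64MB",
        mode="overwrite",
    )

    ds = StreamingDataset("./my_optimized_dataset")
    assert len(ds) == 50  # only the freshly written items remain
```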

## Access any item

Access the data you need, whenever you need it, regardless of where it is stored.
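
As a minimal sketch (the bucket path below is illustrative), any element can be fetched by simply indexing the streaming dataset:

```python
from litdata import StreamingDataset

dataset = StreamingDataset("s3://my-bucket/my-data")  # illustrative path

print(dataset[42])  # fetch a single item without downloading the whole dataset
```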
11 changes: 10 additions & 1 deletion src/litdata/processing/data_processor.py
@@ -389,6 +389,7 @@ def __init__(
num_uploaders: int,
remove: bool,
reader: Optional[BaseReader] = None,
writer_starting_chunk_index: int = 0,
) -> None:
"""The BaseWorker is responsible to process the user data."""
self.worker_index = worker_index
@@ -417,6 +418,7 @@ def __init__(
self._counter = 0
self._last_time = time()
self._index_counter = 0
self.writer_starting_chunk_index = writer_starting_chunk_index

def run(self) -> None:
try:
@@ -510,6 +512,7 @@ def _create_cache(self) -> None:
chunk_bytes=self.data_recipe.chunk_bytes,
chunk_size=self.data_recipe.chunk_size,
compression=self.data_recipe.compression,
writer_chunk_index=self.writer_starting_chunk_index,
)
self.cache._reader._rank = _get_node_rank() * self.num_workers + self.worker_index

@@ -738,7 +741,8 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul

merge_cache = Cache(cache_dir, chunk_bytes=1)
node_rank = _get_node_rank()
merge_cache._merge_no_wait(node_rank if num_nodes > 1 else None)
merge_cache._merge_no_wait(node_rank if num_nodes > 1 else None, getattr(self, "existing_index", None))

self._upload_index(output_dir, cache_dir, num_nodes, node_rank)

if num_nodes == node_rank + 1:
@@ -844,6 +848,7 @@ def __init__(
reorder_files: bool = True,
weights: Optional[List[int]] = None,
reader: Optional[BaseReader] = None,
state_dict: Optional[Dict[int, int]] = None,
):
"""The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make
training faster.
@@ -862,6 +867,7 @@ def __init__(
weights: Provide a list of weights associated to the inputs.
This is used to evenly split the work among the workers.
reader: Maps the inputs to worker inputs and provides a read method to read a slice of the data.
state_dict: The writer state dict. This is used to decide how to append data to an existing dataset.
"""
self.input_dir = _resolve_dir(input_dir)
@@ -881,6 +887,8 @@ def __init__(
self.weights = weights
self.reader = reader

self.state_dict = state_dict or {rank: 0 for rank in range(self.num_workers)}

if self.reader is not None and self.weights is not None:
raise ValueError("Either the reader or the weights needs to be defined.")

@@ -1061,6 +1069,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L
self.num_uploaders,
self.delete_cached_files,
self.reader,
self.state_dict[worker_idx],
)
worker.start()
workers.append(worker)
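
For orientation, here is a hypothetical sketch (the numbers are made up) of what the new `state_dict` argument carries: each entry maps a worker rank to the chunk index its writer should start from, so appended chunks never collide with existing filenames.

```python
# Hypothetical example: a dataset previously written by 4 workers whose highest
# existing chunk indices (per rank) were 9, 9, 8 and 8. The state dict handed to
# DataProcessor would then be:
state_dict = {0: 10, 1: 10, 2: 9, 3: 9}

# Each worker forwards its entry as `writer_starting_chunk_index`, so the first
# chunk it writes is named e.g. "chunk-0-10.bin" instead of "chunk-0-0.bin".
```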
32 changes: 29 additions & 3 deletions src/litdata/processing/functions.py
@@ -18,15 +18,19 @@
from functools import partial
from pathlib import Path
from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
from urllib import parse

import torch

from litdata.constants import _IS_IN_STUDIO
from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from litdata.processing.readers import BaseReader
from litdata.processing.utilities import optimize_dns_context
from litdata.processing.utilities import (
extract_rank_and_index_from_filename,
optimize_dns_context,
read_index_file_content,
)
from litdata.streaming.dataloader import StreamingDataLoader
from litdata.streaming.resolver import (
Dir,
@@ -136,11 +140,13 @@ def __init__(
chunk_size: Optional[int],
chunk_bytes: Optional[Union[int, str]],
compression: Optional[str],
existing_index: Optional[Dict[str, Any]] = None,
):
super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
self._fn = fn
self._inputs = inputs
self.is_generator = False
self.existing_index = existing_index

self.check_fn()

@@ -292,6 +298,7 @@ def optimize(
reorder_files: bool = True,
reader: Optional[BaseReader] = None,
batch_size: Optional[int] = None,
mode: Optional[Literal["append", "overwrite"]] = None,
) -> None:
"""This function converts a dataset into chunks possibly in a distributed way.
@@ -315,8 +322,13 @@
reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
Set this to ``False`` if the order in which samples are processed should be preserved.
batch_size: Group the inputs into batches of batch_size length.
mode: The mode to use when writing the data. Accepts either ``append`` or ``overwrite`` or None.
Defaults to None.
"""
if mode is not None and mode not in ["append", "overwrite"]:
raise ValueError(f"The provided `mode` should be either `append` or `overwrite`. Found {mode}.")

if isinstance(inputs, StreamingDataLoader) and batch_size is not None:
raise ValueError("When providing a streaming dataloader, pass the batch_size to the dataloader directly.")

@@ -353,7 +365,7 @@ def optimize(
" HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`."
)

_assert_dir_has_index_file(_output_dir)
_assert_dir_has_index_file(_output_dir, mode=mode)

if not isinstance(inputs, StreamingDataLoader):
resolved_dir = _resolve_dir(input_dir or _get_input_dir(inputs))
@@ -366,6 +378,18 @@
if num_workers == 0:
num_workers = 1

num_workers = num_workers or _get_default_num_workers()
state_dict = {rank: 0 for rank in range(num_workers)}

existing_index_file_content = read_index_file_content(_output_dir) if mode == "append" else None

if existing_index_file_content is not None:
for chunk in existing_index_file_content["chunks"]:
rank, index = extract_rank_and_index_from_filename(chunk["filename"])

if rank < num_workers and state_dict[rank] <= index:
state_dict[rank] = index + 1 # +1 because we want to start from the next index

data_processor = DataProcessor(
input_dir=resolved_dir,
output_dir=_output_dir,
@@ -375,6 +399,7 @@
num_uploaders=num_uploaders,
reorder_files=reorder_files,
reader=reader,
state_dict=state_dict,
)

with optimize_dns_context(True):
@@ -385,6 +410,7 @@
chunk_size=chunk_size,
chunk_bytes=chunk_bytes,
compression=compression,
existing_index=existing_index_file_content,
)
)
return None
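
To make the append logic above concrete, here is a small sketch (the chunk filenames are made up) of how the per-rank starting indices are derived from an existing `index.json`, mirroring the loop added in `optimize`:

```python
from litdata.processing.utilities import extract_rank_and_index_from_filename

# Hypothetical chunk filenames read from an existing index.json, written by 2 workers.
existing_chunks = ["chunk-0-0.bin", "chunk-0-1.bin", "chunk-1-0.bin"]

num_workers = 2
state_dict = {rank: 0 for rank in range(num_workers)}

for filename in existing_chunks:
    rank, index = extract_rank_and_index_from_filename(filename)
    if rank < num_workers and state_dict[rank] <= index:
        state_dict[rank] = index + 1  # resume writing after the last existing chunk

print(state_dict)  # {0: 2, 1: 1}
```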
76 changes: 74 additions & 2 deletions src/litdata/processing/utilities.py
@@ -12,13 +12,17 @@
# limitations under the License.

import io
import json
import os
import tempfile
import urllib
from contextlib import contextmanager
from subprocess import DEVNULL, Popen
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from urllib import parse

from litdata.constants import _IS_IN_STUDIO, _LIGHTNING_CLOUD_AVAILABLE
from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _LIGHTNING_CLOUD_AVAILABLE
from litdata.streaming.cache import Dir

if _LIGHTNING_CLOUD_AVAILABLE:
from lightning_cloud.openapi import (
@@ -27,6 +31,14 @@
from lightning_cloud.openapi.rest import ApiException
from lightning_cloud.rest_client import LightningClient

try:
import boto3
import botocore

_BOTO3_AVAILABLE = True
except Exception:
_BOTO3_AVAILABLE = False


def _create_dataset(
input_dir: Optional[str],
@@ -177,3 +189,63 @@ def _get_work_dir() -> str:
assert project_id is not None
assert work_id is not None
return f"s3://{bucket_name}/projects/{project_id}/lightningapps/{app_id}/artifacts/{work_id}/content/"


def read_index_file_content(output_dir: Dir) -> Optional[Dict[str, Any]]:
"""Read the index file content."""
if not isinstance(output_dir, Dir):
raise ValueError("The provided output_dir should be a Dir object.")

if output_dir.url is None:
if output_dir.path is None:
return None
index_file_path = os.path.join(output_dir.path, _INDEX_FILENAME)
if not os.path.exists(index_file_path):
return None
with open(index_file_path) as f:
return json.load(f)

else:
# download the index file from s3, and read it
obj = parse.urlparse(output_dir.url)

if obj.scheme != "s3":
raise ValueError(f"The provided folder should start with s3://. Found {output_dir.path}.")

# TODO: Add support for all cloud providers
s3 = boto3.client("s3")

prefix = obj.path.lstrip("/").rstrip("/") + "/"

# Check the index file exists
try:
# Create a temporary file
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file:
temp_file_name = temp_file.name
s3.download_file(obj.netloc, os.path.join(prefix, _INDEX_FILENAME), temp_file_name)
# Read data from the temporary file
with open(temp_file_name) as temp_file:
data = json.load(temp_file)
# Delete the temporary file
os.remove(temp_file_name)
return data
except botocore.exceptions.ClientError:
return None


def extract_rank_and_index_from_filename(chunk_filename: str) -> Tuple[int, int]:
"""Extract the rank and index from the filename.
It is assumed that the filename is in the format `chunk-<rank>-<index>.bin` or
`chunk-<rank>-<index>.compressionAlgorithm.bin`.
"""
# remove chunk and bin
chunk_filename = chunk_filename[6:-4].split("-") # (0, 0) or (0, 0.compressionAlgorithm)
assert len(chunk_filename) == 2

# get the rank and index
rank = int(chunk_filename[0])
index = int(chunk_filename[1].split(".")[0])

return rank, index
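
A quick illustration of the parsing above, assuming the filename convention described in the docstring:

```python
from litdata.processing.utilities import extract_rank_and_index_from_filename

assert extract_rank_and_index_from_filename("chunk-0-12.bin") == (0, 12)
assert extract_rank_and_index_from_filename("chunk-3-4.zstd.bin") == (3, 4)
```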
7 changes: 5 additions & 2 deletions src/litdata/streaming/cache.py
@@ -42,6 +42,7 @@ def __init__(
item_loader: Optional[BaseItemLoader] = None,
max_cache_size: Union[int, str] = "100GB",
serializers: Optional[Dict[str, Serializer]] = None,
writer_chunk_index: Optional[int] = None,
):
"""The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements
together in order to accelerate fetching.
@@ -56,6 +57,7 @@
item_loader: The object responsible for generating the chunk intervals and loading an item from a chunk.
max_cache_size: The maximum cache size used by the reader when fetching the chunks.
serializers: Provide your own serializers.
writer_chunk_index: The index of the chunk to start from when writing.
"""
super().__init__()
@@ -68,6 +70,7 @@
chunk_bytes=chunk_bytes,
compression=compression,
serializers=serializers,
chunk_index=writer_chunk_index or 0,
)
self._reader = BinaryReader(
self._cache_dir,
@@ -137,9 +140,9 @@ def merge(self, num_workers: int = 1, node_rank: Optional[int] = None) -> None:
"""Inform the writer the chunking phase is finished."""
self._writer.merge(num_workers, node_rank=node_rank)

def _merge_no_wait(self, node_rank: Optional[int] = None) -> None:
def _merge_no_wait(self, node_rank: Optional[int] = None, existing_index: Optional[Dict[str, Any]] = None) -> None:
"""Inform the writer the chunking phase is finished."""
self._writer._merge_no_wait(node_rank=node_rank)
self._writer._merge_no_wait(node_rank=node_rank, existing_index=existing_index)

def __len__(self) -> int:
return self._reader.get_length()