Commit 5125e6e

Merge branch 'main' into feat/add-support-encryption

bhimrazy authored Jul 14, 2024
2 parents 736b136 + 225814e commit 5125e6e
Showing 17 changed files with 210 additions and 138 deletions.
60 changes: 60 additions & 0 deletions README.md
@@ -506,6 +506,24 @@ dataset = StreamingDataset(..., max_cache_size="10GB")

</details>

<details>
<summary> ✅ Specify cache directory</summary>
&nbsp;

Specify the directory where cached files should be stored. This is useful for organizing your data storage and can improve access times.

```python
from litdata import StreamingDataset
from litdata.streaming.cache import Dir

cache_dir = "/path/to/your/cache"
data_dir = "s3://my-bucket/my_optimized_dataset"

dataset = StreamingDataset(input_dir=Dir(path=cache_dir, url=data_dir))
```

</details>

<details>
<summary> ✅ Optimize loading on networked drives</summary>
&nbsp;
@@ -520,6 +538,48 @@ from litdata import StreamingDataset
dataset = StreamingDataset(input_dir="local:/data/shared-drive/some-data")
```

</details>

<details>
<summary> ✅ Optimize dataset in distributed environment</summary>
&nbsp;

Lightning can distribute large workloads across hundreds of machines in parallel. This can reduce the time to complete a data processing task from weeks to minutes by scaling to enough machines.

To apply the `optimize` operator across multiple machines, provide the `num_nodes` and `machine` arguments to it as follows:

```python
import os
from litdata import optimize, Machine

def compress(index):
    return (index, index ** 2)

optimize(
    fn=compress,
    inputs=list(range(100)),
    num_workers=2,
    output_dir="my_output",
    chunk_bytes="64MB",
    num_nodes=2,
    machine=Machine.DATA_PREP,  # You can select between dozens of optimized machines
)
```

If the `output_dir` is a local path, the optimized dataset will be present in: `/teamspace/jobs/{job_name}/nodes-0/my_output`. Otherwise, it will be stored in the specified `output_dir`.

Read the optimized dataset:

```python
from litdata import StreamingDataset

output_dir = "/teamspace/jobs/litdata-optimize-2024-07-08/nodes.0/my_output"

dataset = StreamingDataset(output_dir)

print(dataset[:])
```

</details>

&nbsp;
2 changes: 1 addition & 1 deletion src/litdata/constants.py
@@ -34,7 +34,7 @@
_CRYPTOGRAPHY_AVAILABLE = RequirementCache("cryptography")
_GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage")
_TQDM_AVAILABLE = RequirementCache("tqdm")

_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")

# DON'T CHANGE ORDER
_TORCH_DTYPES_MAPPING = {
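
The new `_LIGHTNING_SDK_AVAILABLE` flag follows the same `RequirementCache` guard pattern as the other constants in this file. A minimal sketch of how such a flag is typically consumed (an illustration of the pattern, not code from this commit):

```python
from litdata.constants import _LIGHTNING_SDK_AVAILABLE

# RequirementCache is truthy when the package can be imported; its str()
# carries a human-readable message describing the missing requirement.
if _LIGHTNING_SDK_AVAILABLE:
    import lightning_sdk  # noqa: F401
else:
    raise ModuleNotFoundError(str(_LIGHTNING_SDK_AVAILABLE))
```
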
18 changes: 6 additions & 12 deletions src/litdata/processing/data_processor.py
@@ -30,16 +30,16 @@
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
from urllib import parse

import boto3
import botocore
import numpy as np
import torch

from litdata.constants import (
    _BOTO3_AVAILABLE,
    _DEFAULT_FAST_DEV_RUN_ITEMS,
    _ENABLE_STATUS,
    _INDEX_FILENAME,
    _IS_IN_STUDIO,
    _LIGHTNING_CLOUD_AVAILABLE,
    _TQDM_AVAILABLE,
)
from litdata.processing.readers import BaseReader, StreamingDataLoaderReader
@@ -55,16 +55,6 @@
from litdata.utilities.encryption import Encryption
from litdata.utilities.packing import _pack_greedily

if _TQDM_AVAILABLE:
    from tqdm.auto import tqdm as _tqdm

if _LIGHTNING_CLOUD_AVAILABLE:
    from lightning_cloud.openapi import V1DatasetType

if _BOTO3_AVAILABLE:
    import boto3
    import botocore

logger = logging.Logger(__name__)


@@ -1044,6 +1034,8 @@ def run(self, data_recipe: DataRecipe) -> None:

        current_total = 0
        if _TQDM_AVAILABLE:
            from tqdm.auto import tqdm as _tqdm

            pbar = _tqdm(
                desc="Progress",
                total=num_items,
@@ -1098,6 +1090,8 @@ def run(self, data_recipe: DataRecipe) -> None:
        result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)

        if num_nodes == node_rank + 1 and self.output_dir.url and self.output_dir.path is not None and _IS_IN_STUDIO:
            from lightning_cloud.openapi import V1DatasetType

            _create_dataset(
                input_dir=self.input_dir.path,
                storage_dir=self.output_dir.path,
12 changes: 4 additions & 8 deletions src/litdata/processing/functions.py
@@ -27,7 +27,7 @@

import torch

from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _TQDM_AVAILABLE
from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO
from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from litdata.processing.readers import BaseReader
from litdata.processing.utilities import (
@@ -47,13 +47,7 @@
)
from litdata.utilities._pytree import tree_flatten
from litdata.utilities.encryption import Encryption

if _TQDM_AVAILABLE:
    from tqdm.auto import tqdm as _tqdm
else:

    def _tqdm(iterator: Any) -> Any:
        yield from iterator
from litdata.utilities.format import _get_tqdm_iterator_if_available


def _is_remote_file(path: str) -> bool:
@@ -567,6 +561,8 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None:

    index_json = {"config": input_dirs_file_content[0]["config"], "chunks": chunks}  # type: ignore

    _tqdm = _get_tqdm_iterator_if_available()

    for copy_info in _tqdm(copy_infos):
        _apply_copy(copy_info, resolved_output_dir)

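The `_get_tqdm_iterator_if_available` helper imported above lives in `litdata/utilities/format.py`, which this commit page does not show. A plausible sketch, assuming it simply packages the conditional import and pass-through fallback it replaces here and in `readers.py`:

```python
from typing import Any

from litdata.constants import _TQDM_AVAILABLE


def _get_tqdm_iterator_if_available() -> Any:
    # Prefer the real tqdm progress bar when the optional dependency is installed.
    if _TQDM_AVAILABLE:
        from tqdm.auto import tqdm as _tqdm

        return _tqdm

    # Otherwise return a pass-through generator so call sites stay unchanged.
    def _pass_through(iterator: Any) -> Any:
        yield from iterator

    return _pass_through
```
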
11 changes: 3 additions & 8 deletions src/litdata/processing/readers.py
@@ -16,19 +16,12 @@
from abc import ABC, abstractmethod
from typing import Any, List

from litdata.constants import _TQDM_AVAILABLE
from litdata.imports import RequirementCache
from litdata.streaming.dataloader import StreamingDataLoader
from litdata.utilities.format import _get_tqdm_iterator_if_available

_PYARROW_AVAILABLE = RequirementCache("pyarrow")

if _TQDM_AVAILABLE:
    from tqdm.auto import tqdm as _tqdm
else:

    def _tqdm(iterator: Any) -> Any:
        yield from iterator


class BaseReader(ABC):
    def get_num_nodes(self) -> int:
@@ -92,6 +85,8 @@ def remap_items(self, filepaths: List[str], _: int) -> List[str]:
        cache_folder = os.path.join(self.cache_folder, f"{self.num_rows}")
        os.makedirs(cache_folder, exist_ok=True)

        _tqdm = _get_tqdm_iterator_if_available()

        for filepath in filepaths:
            num_rows = self._get_num_rows(filepath)

24 changes: 8 additions & 16 deletions src/litdata/processing/utilities.py
@@ -21,23 +21,11 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from urllib import parse

from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _LIGHTNING_CLOUD_AVAILABLE
from litdata.streaming.cache import Dir

if _LIGHTNING_CLOUD_AVAILABLE:
    from lightning_cloud.openapi import (
        ProjectIdDatasetsBody,
    )
    from lightning_cloud.openapi.rest import ApiException
    from lightning_cloud.rest_client import LightningClient

try:
import boto3
import botocore
import boto3
import botocore

_BOTO3_AVAILABLE = True
except Exception:
_BOTO3_AVAILABLE = False
from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO
from litdata.streaming.cache import Dir


def _create_dataset(
@@ -67,6 +55,10 @@ def _create_dataset(
    if not storage_dir:
        raise ValueError("The storage_dir should be defined.")

    from lightning_cloud.openapi import ProjectIdDatasetsBody
    from lightning_cloud.openapi.rest import ApiException
    from lightning_cloud.rest_client import LightningClient

    client = LightningClient(retry=False)

    try:
11 changes: 5 additions & 6 deletions src/litdata/streaming/client.py
@@ -15,13 +15,12 @@
from time import time
from typing import Any, Optional

from litdata.constants import _BOTO3_AVAILABLE, _IS_IN_STUDIO
import boto3
import botocore
from botocore.credentials import InstanceMetadataProvider
from botocore.utils import InstanceMetadataFetcher

if _BOTO3_AVAILABLE:
    import boto3
    import botocore
    from botocore.credentials import InstanceMetadataProvider
    from botocore.utils import InstanceMetadataFetcher
from litdata.constants import _IS_IN_STUDIO


class S3Client:
2 changes: 1 addition & 1 deletion src/litdata/streaming/combined.py
@@ -139,7 +139,7 @@ def __iter__(self) -> Iterator[Any]:
        num_samples_yielded = None

        if self._num_samples_yielded is not None and worker_env.rank in self._num_samples_yielded:
            num_samples_yielded = self._num_samples_yielded[worker_env.rank]
            num_samples_yielded = self._num_samples_yielded.get(worker_env.rank, 0)

        self._iterator = _CombinedDatasetIterator(
            self._datasets,
9 changes: 5 additions & 4 deletions src/litdata/streaming/compression.py
@@ -18,9 +18,6 @@

TCompressor = TypeVar("TCompressor", bound="Compressor")

if _ZSTD_AVAILABLE:
    import zstd


class Compressor(ABC):
    """Base class for compression algorithm."""
@@ -45,7 +42,7 @@ class ZSTDCompressor(Compressor):
    def __init__(self, level: int) -> None:
        super().__init__()
        if not _ZSTD_AVAILABLE:
            raise ModuleNotFoundError()
            raise ModuleNotFoundError(str(_ZSTD_AVAILABLE))
        self.level = level
        self.extension = "zstd"

@@ -54,9 +51,13 @@ def name(self) -> str:
        return f"{self.extension}:{self.level}"

    def compress(self, data: bytes) -> bytes:
        import zstd

        return zstd.compress(data, self.level)

    def decompress(self, data: bytes) -> bytes:
        import zstd

        return zstd.decompress(data)

    @classmethod
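
With the deferred imports above, the `zstd` dependency is only needed when compression is actually exercised. A short usage sketch of the class in this hunk (assuming the optional `zstd` package is installed):

```python
from litdata.streaming.compression import ZSTDCompressor

# Hypothetical usage; without the optional `zstd` package the constructor
# raises ModuleNotFoundError carrying the requirement message.
compressor = ZSTDCompressor(level=4)

payload = b"repetitive payload " * 1_000
blob = compressor.compress(payload)
assert compressor.decompress(blob) == payload
```
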
12 changes: 5 additions & 7 deletions src/litdata/streaming/downloader.py
@@ -23,13 +23,6 @@
from litdata.constants import _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME
from litdata.streaming.client import S3Client

if _GOOGLE_STORAGE_AVAILABLE:
    from google.cloud import storage
else:

    class storage:  # type: ignore
        Client = None


class Downloader(ABC):
    def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
@@ -96,9 +89,14 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:

class GCPDownloader(Downloader):
    def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
        if not _GOOGLE_STORAGE_AVAILABLE:
            raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE))

        super().__init__(remote_dir, cache_dir, chunks)

    def download_file(self, remote_filepath: str, local_filepath: str) -> None:
        from google.cloud import storage

        obj = parse.urlparse(remote_filepath)

        if obj.scheme != "gs":
Expand Down