Commit

Merge branch 'main' into fix/combined-dataset-loading-states
bhimrazy authored Dec 17, 2024
2 parents 0ee0617 + 0a97def commit 424bc64
Showing 29 changed files with 963 additions and 52 deletions.
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
@@ -7,3 +7,5 @@
# CI/CD and configs
/.github/ @borda
*.yml @borda

/src @tchaton
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
@@ -2,7 +2,7 @@
<summary><b>Before submitting</b></summary>

- [ ] Was this discussed/agreed via a Github issue? (no need for typos and docs improvements)
- [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/lit-data/blob/main/.github/CONTRIBUTING.md), Pull Request section?
- [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/lit-data/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

355 changes: 355 additions & 0 deletions CONTRIBUTING.md

Large diffs are not rendered by default.

18 changes: 17 additions & 1 deletion Makefile
@@ -1,10 +1,14 @@
.PHONY: test clean docs
.PHONY: test clean docs install-pre-commit install-dependencies setup

# to imitate SLURM, set only a single node
export SLURM_LOCALID=0
# assume you have installed the needed packages
export SPHINX_MOCK_REQUIREMENTS=0

setup: install-dependencies install-pre-commit
@echo "==================== Setup Finished ===================="
@echo "All set! Ready to go!"

test: clean
pip install -q -r requirements.txt
pip install -q -r requirements/test.txt
@@ -28,3 +32,15 @@ clean:
rm -rf ./src/*.egg-info
rm -rf ./build
rm -rf ./dist

install-dependencies:
pip install -r requirements.txt
pip install -r requirements/test.txt
pip install -r requirements/docs.txt
pip install -r requirements/extras.txt
pip install -e .


install-pre-commit:
pip install pre-commit
pre-commit install
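
Note: per the `setup` target added above, contributors can bootstrap the repository with a single `make setup`, which runs `install-dependencies` followed by `install-pre-commit`.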
110 changes: 108 additions & 2 deletions README.md
@@ -225,6 +225,15 @@ storage_options = {
dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options)
```


You can also specify a custom cache directory when initializing your dataset. This is useful when you want to store the cache in a specific location.
```python
from litdata import StreamingDataset

# Initialize the StreamingDataset with the custom cache directory
dataset = StreamingDataset('s3://my-bucket/my-data', cache_dir="/path/to/cache")
```

</details>

<details>
@@ -390,6 +399,68 @@ for batch in tqdm(train_dataloader):

</details>

<details>
<summary> ✅ Filter illegal data </summary>
&nbsp;

Sometimes your data contains bad samples that you don't want to include in the optimized dataset. With LitData, yield only the samples you want to keep.


```python
from litdata import optimize, StreamingDataset

def should_keep(index) -> bool:
# Replace with your own logic
return index % 2 == 0


def fn(data):
if should_keep(data):
yield data

if __name__ == "__main__":
optimize(
fn=fn,
inputs=list(range(1000)),
output_dir="only_even_index_optimized",
chunk_bytes="64MB",
num_workers=1
)

dataset = StreamingDataset("only_even_index_optimized")
data = list(dataset)
print(data)
# [0, 2, 4, 6, 8, 10, ..., 992, 994, 996, 998]
```

You can even use try/except.

```python
from litdata import optimize, StreamingDataset

def fn(data):
try:
yield 1 / data
except:
pass

if __name__ == "__main__":
optimize(
fn=fn,
inputs=[0, 0, 0, 1, 2, 4, 0],
output_dir="only_defined_ratio_optimized",
chunk_bytes="64MB",
num_workers=1
)

dataset = StreamingDataset("only_defined_ratio_optimized")
data = list(dataset)
# The zeros are filtered out because they raise a ZeroDivisionError
print(data)
# [1.0, 0.5, 0.25]
```
</details>

<details>
<summary> ✅ Combine datasets</summary>
&nbsp;
@@ -430,6 +501,41 @@ for batch in tqdm(train_dataloader):
```
</details>

<details>
<summary> ✅ Merge datasets</summary>
&nbsp;

Merge multiple optimized datasets into one.

```python
import numpy as np
from PIL import Image

from litdata import StreamingDataset, merge_datasets, optimize


def random_images(index):
return {
"index": index,
"image": Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)),
"class": np.random.randint(10),
}


if __name__ == "__main__":
out_dirs = ["fast_data_1", "fast_data_2", "fast_data_3", "fast_data_4"]  # or ["s3://my-bucket/fast_data_1", etc.]
for out_dir in out_dirs:
optimize(fn=random_images, inputs=list(range(250)), output_dir=out_dir, num_workers=4, chunk_bytes="64MB")

merged_out_dir = "merged_fast_data" # or "s3://my-bucket/merged_fast_data"
merge_datasets(input_dirs=out_dirs, output_dir=merged_out_dir)

dataset = StreamingDataset(merged_out_dir)
print(len(dataset))
# out: 1000
```
</details>

<details>
<summary> ✅ Split datasets for train, val, test</summary>

@@ -866,7 +972,7 @@ Speed to stream Imagenet 1.2M from AWS S3:

| Framework | Images / sec 1st Epoch (float32) | Images / sec 2nd Epoch (float32) | Images / sec 1st Epoch (torch16) | Images / sec 2nd Epoch (torch16) |
|---|---|---|---|---|
| PL Data | **5800** | **6589** | **6282** | **7221** |
| LitData | **5800** | **6589** | **6282** | **7221** |
| Web Dataset | 3134 | 3924 | 3343 | 4424 |
| Mosaic ML | 2898 | 5099 | 2809 | 5158 |

@@ -887,7 +993,7 @@ LitData optimizes the Imagenet dataset for fast training 3-5x faster than other
Time to optimize 1.2 million ImageNet images (Faster is better):
| Framework |Train Conversion Time | Val Conversion Time | Dataset Size | # Files |
|---|---|---|---|---|
| PL Data | **10:05 min** | **00:30 min** | **143.1 GB** | 2.339 |
| LitData | **10:05 min** | **00:30 min** | **143.1 GB** | 2.339 |
| Web Dataset | 32:36 min | 01:22 min | 147.8 GB | 1.144 |
| Mosaic ML | 49:49 min | 01:04 min | **143.1 GB** | 2.298 |

1 change: 1 addition & 0 deletions requirements.txt
@@ -4,3 +4,4 @@ filelock
numpy
boto3
requests
tifffile
4 changes: 3 additions & 1 deletion setup.py
@@ -81,7 +81,7 @@ def _prepare_extras(requirements_dir: str = _PATH_REQUIRES, skip_files: tuple =
"Natural Language :: English",
# How mature is this project? Common values are
# 3 - Alpha, 4 - Beta, 5 - Production/Stable
"Development Status :: 3 - Alpha",
"Development Status :: 4 - Beta",
# Indicate who your project is intended for
"Intended Audience :: Developers",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
@@ -92,6 +92,8 @@ def _prepare_extras(requirements_dir: str = _PATH_REQUIRES, skip_files: tuple =
# Specify the Python versions you support here. In particular, ensure
# that you indicate whether you support Python 2, Python 3 or both.
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
2 changes: 1 addition & 1 deletion src/litdata/__about__.py
@@ -14,7 +14,7 @@

import time

__version__ = "0.2.28"
__version__ = "0.2.34"
__author__ = "Lightning AI et al."
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
1 change: 1 addition & 0 deletions src/litdata/constants.py
@@ -23,6 +23,7 @@
_DEFAULT_CHUNK_BYTES = 1 << 26  # 64 MB
_DEFAULT_FAST_DEV_RUN_ITEMS = 10
_DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks")
_DEFAULT_LIGHTNING_CACHE_DIR = os.path.join("/cache", "chunks")

# This is required for full pytree serialization / deserialization support
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
61 changes: 61 additions & 0 deletions src/litdata/helpers.py
@@ -0,0 +1,61 @@
import functools
import warnings
from typing import Any, Optional

import requests
from packaging import version as packaging_version


class WarningCache(set):
"""Cache for warnings."""

def warn(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
"""Trigger warning message."""
if message not in self:
self.add(message)
warnings.warn(message, stacklevel=stacklevel, **kwargs)


warning_cache = WarningCache()

__package_name__ = "litdata"


@functools.lru_cache(maxsize=1)
def _get_newer_version(curr_version: str) -> Optional[str]:
"""Check PyPI for newer versions of ``litdata``.
Returning the newest version if different from the current or ``None`` otherwise.
"""
if packaging_version.parse(curr_version).is_prerelease:
return None
try:
response = requests.get(f"https://pypi.org/pypi/{__package_name__}/json", timeout=30)
response_json = response.json()
releases = response_json["releases"]
if curr_version not in releases:
# Always return None if not installed from PyPI (e.g. dev versions)
return None
latest_version = response_json["info"]["version"]
parsed_version = packaging_version.parse(latest_version)
is_invalid = response_json["info"]["yanked"] or parsed_version.is_devrelease or parsed_version.is_prerelease
return None if curr_version == latest_version or is_invalid else latest_version
except requests.exceptions.RequestException:
return None


def _check_version_and_prompt_upgrade(curr_version: str) -> None:
"""Checks that the current version of ``litdata`` is the latest on PyPI.
If not, warn the user to upgrade ``litdata``.
"""
new_version = _get_newer_version(curr_version)
if new_version:
warning_cache.warn(
f"A newer version of {__package_name__} is available ({new_version}). "
f"Please consider upgrading with `pip install -U {__package_name__}`. "
"Not all functionalities of the platform can be guaranteed to work with the current version.",
)
return
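
For orientation, here is a minimal usage sketch of the helper module added above. The import paths come from this diff, but the call site shown is an assumption for illustration only, not part of this commit:

```python
# Hypothetical usage sketch (not part of this diff): run the PyPI version check
# against the installed litdata version, e.g. at package import or job start.
from litdata.__about__ import __version__
from litdata.helpers import _check_version_and_prompt_upgrade

# Warns at most once per message if a newer, non-yanked litdata release is on PyPI.
_check_version_and_prompt_upgrade(__version__)
```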
33 changes: 25 additions & 8 deletions src/litdata/processing/data_processor.py
@@ -456,6 +456,7 @@ def run(self) -> None:
try:
self._setup()
self._loop()
self._terminate()
except Exception:
traceback_format = traceback.format_exc()
self.error_queue.put(traceback_format)
@@ -469,6 +470,19 @@ def _setup(self) -> None:
self._start_uploaders()
self._start_remover()

def _terminate(self) -> None:
"""Make sure all the uploaders, downloaders and removers are terminated."""
for uploader in self.uploaders:
if uploader.is_alive():
uploader.join()

for downloader in self.downloaders:
if downloader.is_alive():
downloader.join()

if self.remover and self.remover.is_alive():
self.remover.join()

def _loop(self) -> None:
num_downloader_finished = 0

@@ -795,7 +809,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul

chunks = [file for file in os.listdir(cache_dir) if file.endswith(".bin")]
if chunks and delete_cached_files and output_dir.path is not None:
raise RuntimeError(f"All the chunks should have been deleted. Found {chunks}")
raise RuntimeError(f"All the chunks should have been deleted. Found {chunks} in cache: {cache_dir}")

merge_cache = Cache(cache_dir, chunk_bytes=1)
node_rank = _get_node_rank()
@@ -1110,25 +1124,28 @@ def run(self, data_recipe: DataRecipe) -> None:

current_total = new_total
if current_total == num_items:
# make sure all processes are terminated
for w in self.workers:
if w.is_alive():
w.join()
break

if _IS_IN_STUDIO and node_rank == 0 and _ENABLE_STATUS:
with open("status.json", "w") as f:
json.dump({"progress": str(100 * current_total * num_nodes / total_num_items) + "%"}, f)

# Exit early if all the workers are done.
# This means there were some kinda of errors.
# This means either there were some kind of errors, or the optimize function was very small.
if all(not w.is_alive() for w in self.workers):
raise RuntimeError("One of the worker has failed")
try:
error = self.error_queue.get(timeout=0.01)
self._exit_on_error(error)
except Empty:
continue

if _TQDM_AVAILABLE:
pbar.close()

# TODO: Check whether this is still required.
if num_nodes == 1:
for w in self.workers:
w.join()

print("Workers are finished.")
result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)
