diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 5ffd4453..ee4e63eb 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -7,3 +7,5 @@
# CI/CD and configs
/.github/ @borda
*.yml @borda
+
+/src @tchaton
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 2f260695..50f253f9 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -2,7 +2,7 @@
Before submitting
- [ ] Was this discussed/agreed via a Github issue? (no need for typos and docs improvements)
-- [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/lit-data/blob/main/.github/CONTRIBUTING.md), Pull Request section?
+- [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/lit-data/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 00000000..43947cdc
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,355 @@
+# Contributing
+
+Welcome to the PyTorch Lightning community! We're building the most advanced research platform on the planet to implement the latest best practices and integrations that the amazing PyTorch team and other research organizations roll out!
+
+If you are new to open source, check out [this blog to get started with your first Open Source contribution](https://devblog.pytorchlightning.ai/quick-contribution-guide-86d977171b3a).
+
+## Main Core Value: One less thing to remember
+
+Simplify the API as much as possible from the user perspective.
+Any additions or improvements should minimize the things the user needs to remember.
+
+For example: One benefit of the `validation_step` is that the user doesn't have to remember to set the model to `.eval()`.
+This helps users avoid all sorts of subtle errors.
+
+## Lightning Design Principles
+
+We encourage all sorts of contributions you're interested in adding! When coding for Lightning, please follow these principles.
+
+### No PyTorch Interference
+
+We don't want to add any abstractions on top of pure PyTorch.
+This gives researchers all the control they need without having to learn yet another framework.
+
+### Simple Internal Code
+
+It's useful for users to look at the code and understand very quickly what's happening.
+Many users won't be engineers. Thus we need to value clear, simple code over condensed ninja moves.
+While that's super cool, this isn't the project for that :)
+
+### Force User Decisions To Best Practices
+
+There are 1,000 ways to do something. However, eventually one popular solution becomes standard practice, and everyone follows.
+We try to find the best way to solve a particular problem, and then force our users to use it for readability and simplicity.
+A good example is accumulated gradients.
+There are many different ways to implement it; we just pick one and have users follow it.
+A bad forced decision would be to make users use a specific library to do something.
+
+When something becomes a best practice, we add it to the framework. This is usually something like bits of code in utils or in the model file that everyone keeps adding over and over again across projects. When this happens, we bring that code inside the trainer and add a flag for it.
+
+### Simple External API
+
+What makes sense to you may not make sense to others. When creating an issue with an API change suggestion, please validate that it makes sense for others.
+Treat code changes the way you treat a startup: validate that it's a needed feature, then add if it makes sense for many people.
+
+### Backward-compatible API
+
+We all hate updating our deep learning packages because we don't want to refactor a bunch of stuff. In Lightning, we make sure every change that could break an API is backward compatible and ships with good deprecation warnings.
+
+**You shouldn't be afraid to upgrade Lightning :)**
+
+### Gain User Trust
+
+As a researcher, you can't have any part of your code going wrong. So, make thorough tests to ensure that every implementation of a new trick or subtle change is correct.
+
+### Interoperability
+
+Have a favorite feature from other libraries like fast.ai or transformers? Those should just work with Lightning as well. Grab your favorite model or learning rate scheduler from your favorite library and run it in Lightning.
+
+______________________________________________________________________
+
+## Contribution Types
+
+We are always open to contributions of new features or bug fixes.
+
+A lot of good work has already been done in project mechanics (requirements.txt, setup.py, PEP 8, badges, CI, etc.), so we're in a good state there thanks to all the early contributors (even pre-beta release)!
+
+### Bug Fixes:
+
+1. If you find a bug, please submit a GitHub issue.
+
+   - Make sure the title explains the issue.
+   - Describe your setup, what you are trying to do, and the expected vs. actual behaviour. Please add configs and code samples.
+   - Add details on how to reproduce the issue - a minimal test case is always best, and a Colab notebook is also great.
+     Note that the sample code should be minimal and, if needed, use publicly available data.
+
+2. Try to fix it or recommend a solution. We highly recommend using a test-driven approach:
+
+ - Convert your minimal code example to a unit/integration test with assert on expected results.
+ - Start by debugging the issue... You can run just this particular test in your IDE and draft a fix.
+ - Verify that your test case fails on the master branch and only passes with the fix applied.
+
+3. Submit a PR!
+
+_**Note**, even if you do not find the solution, sending a PR with a test covering the issue is a valid contribution, and we can help you or finish it with you :\]_
+
+### New Features:
+
+1. Submit a GitHub issue - describe the motivation for the feature (adding a use case or an example is helpful).
+
+2. Determine the feature scope with us.
+
+3. Submit a PR! We recommend a test-driven approach for adding new features as well:
+
+ - Write a test for the functionality you want to add.
+ - Write the functional code until the test passes.
+
+4. Add/update the relevant tests!
+
+- [This PR](https://github.com/Lightning-AI/litdata/pull/237) (`Fix uneven batches in distributed dataloading`) is a good example: it adds the relevant tests, fixes the code, and has a good description of the problem and solution.
+
+### Test cases:
+
+Want to keep Lightning healthy? Love seeing those green tests? So do we! How do we keep it that way? We write tests! We value test contributions even more than new features.
+
+Want to add a new test case and not sure how? [Talk to us!](https://lightning.ai/community)
+
+______________________________________________________________________
+
+## Guidelines
+
+### Original code
+
+All added or edited code shall be the original work of the particular contributor.
+If you use a third-party implementation, all such blocks/functions/modules shall be properly referenced and, if possible, also agreed to by the code's author. For example: `This code is inspired by http://...`.
+In case you are adding new dependencies, make sure that they are compatible with the actual [LitData license](https://github.com/Lightning-AI/litdata/blob/main/LICENSE) (i.e. dependencies should be _at least_ as permissive as the LitData license).
+
+### Coding Style
+
+1. Use f-strings for output formatting, except for logging, where we stay with lazy `logging.info("Hello %s!", name)` - see the short example below.
+2. You can use [pre-commit](https://pre-commit.com/) to make sure your code style is correct.
+
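+For example, here is a minimal sketch of the two styles from point 1 (illustrative only):
+
+```python
+import logging
+
+logger = logging.getLogger(__name__)
+name = "Lightning"
+
+# f-strings for regular output formatting
+print(f"Hello {name}!")
+
+# lazy %-style formatting for logging, so the message is only built if the record is emitted
+logger.info("Hello %s!", name)
+```
+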
+### Documentation
+
+To learn about development of docs, check out the docs [README.md](https://github.com/Lightning-AI/litdata/blob/main/README.md).
+
+### Pull Request
+
+We welcome any useful contribution! For your convenience, here's a recommended workflow:
+
+1. Think about what you want to do - fix a bug, repair docs, etc. If you want to implement a new feature or enhance an existing one:
+
+ - Start by opening a GitHub issue to explain the feature and the motivation.
+ In the case of features, ask yourself first - Is this NECESSARY for Lightning? There are some PRs that are just purely about adding engineering complexity which has no place in Lightning.
+ - Core contributors will take a look (it might take some time - we are often overloaded with issues!) and discuss it.
+   - Once an agreement is reached, start coding.
+
+2. Start your work locally.
+
+ - Create a branch and prepare your changes.
+   - Tip: do not work on your master branch directly; it may become complicated when you need to rebase.
+ - Tip: give your PR a good name! It will be useful later when you may work on multiple tasks/PRs.
+
+3. Test your code!
+
+   - It is always good practice to start coding by creating a test case, verifying that it breaks with the current behaviour and passes with your new changes.
+   - Make sure your new tests cover all different edge cases.
+   - Make sure all exceptions raised are tested.
+   - Make sure all warnings raised are tested (see the pytest sketch right after this list).
+
+4. If your PR is not ready for reviews, but you want to run it on our CI, open a "Draft PR" to let us know you don't need feedback yet.
+
+5. If any of the existing tests fail in your PR on our CI, identify what's failing and try to address it.
+
+6. When you feel ready to integrate your work, mark your PR "Ready for review".
+
+ - Your code should be readable and follow the project's design principles.
+   - Make sure all tests are passing and any new code is covered by tests (coverage!).
+ - Make sure you link the GitHub issue to your PR.
+ - Make sure any docs for that piece of code are updated, or added.
+ - The code should be elegant and simple. No over-engineering or hard-to-read code.
+
+   Do your best but don't sweat over perfection! We do code reviews to find any missed items.
+ If you need help, don't hesitate to ping the core team on the PR.
+
+7. Use tags in the PR name for the following cases:
+
+ - **\[blocked by #\]** if your work is dependent on other PRs.
+   - **\[wip\]** when you start to re-edit your work; mark it so no one will accidentally merge it in the meantime.
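+
+As mentioned in step 3, every exception and warning you raise should be covered by a test. Here is a minimal pytest sketch (the `divide` function below is a placeholder, not a LitData API):
+
+```python
+import warnings
+
+import pytest
+
+
+def divide(a: float, b: float) -> float:
+    # placeholder for the code under test
+    if b == 0:
+        raise ValueError("b must not be zero")
+    return a / b
+
+
+def test_divide_raises_on_zero():
+    # assert that the expected exception is raised
+    with pytest.raises(ValueError, match="must not be zero"):
+        divide(1.0, 0.0)
+
+
+def test_warning_is_emitted():
+    # assert that the expected warning is emitted
+    with pytest.warns(DeprecationWarning):
+        warnings.warn("this API is deprecated", DeprecationWarning)
+```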
+
+______________________________________________________________________
+
+### Setup LitData
+
+We use a Makefile for easier setup. To set up LitData, follow these steps:
+
+1. Clone the forked repository:
+
+```bash
+git clone https://github.com/{YOUR_USERNAME}/litdata.git
+cd litdata
+```
+
+2. Optional, but recommended: Create a virtual conda environment. If you're using [Lightning Studio](https://lightning.ai/studios), skip this step.
+
+```bash
+conda create -n litdata python=3.10
+conda activate litdata
+```
+
+3. Make sure you have `make` installed
+
+```bash
+# for debian-based systems: sudo apt-get install build-essential
+# for mac: brew install make
+make --version
+```
+
+4. Run command:
+
+```bash
+make setup
+```
+
+That's it. You are ready to go! 🎉
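+
+Optionally, you can sanity-check your environment by running the test suite through the Makefile's `test` target:
+
+```bash
+make test
+```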
+
+______________________________________________________________________
+
+### Question & Answer
+
+#### How can I help/contribute?
+
+All types of contributions are welcome - reporting bugs, fixing documentation, adding test cases, solving issues, and preparing bug fixes.
+To get started with code contributions, look for issues marked with the label [good first issue](https://github.com/Lightning-AI/litdata/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) or choose something close to your domain with the label [help wanted](https://github.com/Lightning-AI/litdata/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22). Before coding, make sure that the issue description is clear and comment on the issue so that we can assign it to you (or simply self-assign if you can).
+
+#### Is there a recommendation for branch names?
+
+We recommend you follow this convention: `<type>/<issue-id>_<short-name>`, where the type is one of `bugfix`, `feature`, `docs`, or `tests` (but if you are using your own fork, that's optional).
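+
+For example, for a bug fix tracked in a hypothetical issue #1234:
+
+```bash
+git checkout -b bugfix/1234_fix-division-by-zero
+```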
+
+#### How to rebase my PR?
+
+We recommend creating a PR in a separate branch other than `master`, especially if you plan to submit several changes and do not want to wait until the first one is resolved (we can work on them in parallel).
+
+First, make sure you have set [upstream](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/configuring-a-remote-repository-for-a-fork) by running:
+
+```bash
+git remote add upstream https://github.com/Lightning-AI/litdata.git
+```
+
+You'll know it's set up right if you run `git remote -v` and see something similar to this:
+
+```bash
+origin https://github.com/{YOUR_USERNAME}/litdata.git (fetch)
+origin https://github.com/{YOUR_USERNAME}/litdata.git (push)
+upstream https://github.com/Lightning-AI/litdata.git (fetch)
+upstream https://github.com/Lightning-AI/litdata.git (push)
+```
+
+Check out your feature branch and rebase it onto upstream's master before pushing it up:
+
+```bash
+git fetch --all --prune
+git rebase upstream/master
+# follow git instructions to resolve conflicts
+git push -f origin {BRANCH_NAME} # Use -f if you've pushed this branch before
+```
+
+#### How to add new tests?
+
+We are using [pytest](https://docs.pytest.org/en/stable/) in LitData.
+
+Here are tutorials:
+
+- (recommended) [Visual Testing with pytest](https://www.youtube.com/playlist?list=PLCTHcU1KoD99Rim2tzg-IhYY2iu9FFvNo) from JetBrains on YouTube
+- [Effective Python Testing With Pytest](https://realpython.com/pytest-python-testing/) article on realpython.com
+
+Here is the process to create a new test:
+
+- 0. Optional: Follow tutorials!
+- 1. Find a file in tests/ which matches what you want to test. If none, create one.
+- 2. Use this template to get started!
+
+```python
+# TEST SHOULD BE IN YOUR FILE: tests/.../test_file.py
+# TEST CODE TEMPLATE
+import os
+
+
+# [OPTIONAL] pytest decorator
+# @RunIf(min_cuda_gpus=1)
+def test_explain_what_is_being_tested(tmpdir):
+    """Briefly describe what is being tested and why."""
+    cache_dir = os.path.join(tmpdir, "cache_dir")
+
+ # assert the behaviour is correct.
+ assert ...
+```
+
+Run your test with:
+
+```bash
+python -m pytest tests/.../test_file.py::test_explain_what_is_being_tested -v --capture=no
+```
+
+#### How to fix PR with mixed base and target branches?
+
+Sometimes you start your PR as a bug fix, but it turns out to be more of a feature (or the other way around).
+Do not panic; the solution is straightforward.
+All you need to do is follow these two steps, in arbitrary order:
+
+- Ask someone from Core to change the base/target branch to the correct one
+- Rebase or cherry-pick your commits onto the correct base branch...
+
+Let's show how to deal with the git...
+the sample case is moving a PR from `master` to `release/1.2-dev`, assuming your branch name is `my-branch`,
+the last true master commit is `ccc111`, and your first commit is `mmm222`.
+
+- **Cherry-picking** way
+ ```bash
+ git checkout my-branch
+ # create a local backup of your branch
+ git branch my-branch-backup
+ # reset your branch to the correct base
+ git reset release/1.2-dev --hard
+  # ACTION: this step is much easier to do with an IDE,
+  # so open one and cherry-pick your last commits from `my-branch-backup`
+  # resolve any conflicts, as the new base may contain different code
+ # when all done, push back to the open PR
+ git push -f
+ ```
+- **Rebasing way**, see more about [rebase onto usage](https://www.atlassian.com/git/tutorials/rewriting-history/git-rebase)
+ ```bash
+ git checkout my-branch
+ # rebase your commits on the correct branch
+ git rebase --onto release/1.2-dev ccc111
+  # if there are no conflicts, the rebase will just succeed
+  # otherwise, resolve the conflicts following the instructions in the terminal
+ # when all done, push back to the open PR
+ git push -f
+ ```
+
+### Bonus Workflow Tip
+
+If you don't want to remember all the commands above every time you want to push some code or set up a Lightning dev environment on a new VM, you can set up bash aliases for some common commands. You can add these to one of your `~/.bashrc`, `~/.zshrc`, or `~/.bash_aliases` files.
+
+NOTE: Once you edit one of these files, remember to `source` it or restart your shell (e.g. `source ~/.bashrc` if you added these to your `~/.bashrc` file).
+
+```bash
+plclone (){
+ git clone https://github.com/{YOUR_USERNAME}/litdata.git
+ cd litdata
+ git remote add upstream https://github.com/Lightning-AI/litdata.git
+ # This is just here to print out info about your remote upstream/origin
+ git remote -v
+}
+
+plfetch (){
+ git fetch --all --prune
+ git checkout master
+ git merge upstream/master
+}
+
+# Rebase your feature branch onto master (run plfetch first to update master from upstream)
+# Usage: plrebase your-branch-name
+plrebase (){
+ git checkout $@
+ git rebase master
+}
+```
+
+Now, you can:
+
+- clone your fork and set up upstream by running `plclone` from your terminal
+- fetch upstream and update your local master branch with it by running `plfetch`
+- rebase your feature branch (after running `plfetch`) by running `plrebase your-branch-name`
diff --git a/Makefile b/Makefile
index cf5250aa..04571097 100644
--- a/Makefile
+++ b/Makefile
@@ -1,10 +1,14 @@
-.PHONY: test clean docs
+.PHONY: test clean docs install-pre-commit install-dependencies setup
# to imitate SLURM set only single node
export SLURM_LOCALID=0
# assume you have installed need packages
export SPHINX_MOCK_REQUIREMENTS=0
+setup: install-dependencies install-pre-commit
+ @echo "==================== Setup Finished ===================="
+ @echo "All set! Ready to go!"
+
test: clean
pip install -q -r requirements.txt
pip install -q -r requirements/test.txt
@@ -28,3 +32,15 @@ clean:
rm -rf ./src/*.egg-info
rm -rf ./build
rm -rf ./dist
+
+install-dependencies:
+ pip install -r requirements.txt
+ pip install -r requirements/test.txt
+ pip install -r requirements/docs.txt
+ pip install -r requirements/extras.txt
+ pip install -e .
+
+
+install-pre-commit:
+ pip install pre-commit
+ pre-commit install
diff --git a/README.md b/README.md
index e03d0c13..0bbd5090 100644
--- a/README.md
+++ b/README.md
@@ -225,6 +225,15 @@ storage_options = {
dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options)
```
+
+Also, you can specify a custom cache directory when initializing your dataset. This is useful when you want to store the cache in a specific location.
+```python
+from litdata import StreamingDataset
+
+# Initialize the StreamingDataset with the custom cache directory
+dataset = StreamingDataset('s3://my-bucket/my-data', cache_dir="/path/to/cache")
+```
+
@@ -390,6 +399,68 @@ for batch in tqdm(train_dataloader):
+
+ ✅ Filter illegal data
+
+
+Sometimes, you have bad data that you don't want to include in the optimized dataset. With LitData, simply yield only the good samples you want to include.
+
+
+```python
+from litdata import optimize, StreamingDataset
+
+def should_keep(index) -> bool:
+    # Replace with your own logic
+ return index % 2 == 0
+
+
+def fn(data):
+ if should_keep(data):
+ yield data
+
+if __name__ == "__main__":
+ optimize(
+ fn=fn,
+ inputs=list(range(1000)),
+ output_dir="only_even_index_optimized",
+ chunk_bytes="64MB",
+ num_workers=1
+ )
+
+ dataset = StreamingDataset("only_even_index_optimized")
+ data = list(dataset)
+ print(data)
+ # [0, 2, 4, 6, 8, 10, ..., 992, 994, 996, 998]
+```
+
+You can even use try/except.
+
+```python
+from litdata import optimize, StreamingDataset
+
+def fn(data):
+ try:
+ yield 1 / data
+    except ZeroDivisionError:
+ pass
+
+if __name__ == "__main__":
+ optimize(
+ fn=fn,
+ inputs=[0, 0, 0, 1, 2, 4, 0],
+ output_dir="only_defined_ratio_optimized",
+ chunk_bytes="64MB",
+ num_workers=1
+ )
+
+ dataset = StreamingDataset("only_defined_ratio_optimized")
+ data = list(dataset)
+    # The 0s are filtered out as they raise a ZeroDivisionError
+ print(data)
+ # [1.0, 0.5, 0.25]
+```
+
+
✅ Combine datasets
@@ -430,6 +501,41 @@ for batch in tqdm(train_dataloader):
```
+
+ ✅ Merge datasets
+
+
+Merge multiple optimized datasets into one.
+
+```python
+import numpy as np
+from PIL import Image
+
+from litdata import StreamingDataset, merge_datasets, optimize
+
+
+def random_images(index):
+ return {
+ "index": index,
+ "image": Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)),
+ "class": np.random.randint(10),
+ }
+
+
+if __name__ == "__main__":
+    out_dirs = ["fast_data_1", "fast_data_2", "fast_data_3", "fast_data_4"]  # or ["s3://my-bucket/fast_data_1", etc.]
+ for out_dir in out_dirs:
+ optimize(fn=random_images, inputs=list(range(250)), output_dir=out_dir, num_workers=4, chunk_bytes="64MB")
+
+ merged_out_dir = "merged_fast_data" # or "s3://my-bucket/merged_fast_data"
+ merge_datasets(input_dirs=out_dirs, output_dir=merged_out_dir)
+
+ dataset = StreamingDataset(merged_out_dir)
+ print(len(dataset))
+ # out: 1000
+```
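+
+The optional `max_workers` argument controls how many threads are used to copy the chunk files during the merge (it defaults to `os.cpu_count()`):
+
+```python
+merge_datasets(input_dirs=out_dirs, output_dir=merged_out_dir, max_workers=8)
+```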
+
+
✅ Split datasets for train, val, test
@@ -866,7 +972,7 @@ Speed to stream Imagenet 1.2M from AWS S3:
| Framework | Images / sec 1st Epoch (float32) | Images / sec 2nd Epoch (float32) | Images / sec 1st Epoch (torch16) | Images / sec 2nd Epoch (torch16) |
|---|---|---|---|---|
-| PL Data | **5800** | **6589** | **6282** | **7221** |
+| LitData | **5800** | **6589** | **6282** | **7221** |
| Web Dataset | 3134 | 3924 | 3343 | 4424 |
| Mosaic ML | 2898 | 5099 | 2809 | 5158 |
@@ -887,7 +993,7 @@ LitData optimizes the Imagenet dataset for fast training 3-5x faster than other
Time to optimize 1.2 million ImageNet images (Faster is better):
| Framework |Train Conversion Time | Val Conversion Time | Dataset Size | # Files |
|---|---|---|---|---|
-| PL Data | **10:05 min** | **00:30 min** | **143.1 GB** | 2.339 |
+| LitData | **10:05 min** | **00:30 min** | **143.1 GB** | 2.339 |
| Web Dataset | 32:36 min | 01:22 min | 147.8 GB | 1.144 |
| Mosaic ML | 49:49 min | 01:04 min | **143.1 GB** | 2.298 |
diff --git a/requirements.txt b/requirements.txt
index 06a629a0..76da06d2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,4 @@ filelock
numpy
boto3
requests
+tifffile
diff --git a/setup.py b/setup.py
index 441b90d2..61663b5d 100644
--- a/setup.py
+++ b/setup.py
@@ -81,7 +81,7 @@ def _prepare_extras(requirements_dir: str = _PATH_REQUIRES, skip_files: tuple =
"Natural Language :: English",
# How mature is this project? Common values are
# 3 - Alpha, 4 - Beta, 5 - Production/Stable
- "Development Status :: 3 - Alpha",
+ "Development Status :: 4 - Beta",
# Indicate who your project is intended for
"Intended Audience :: Developers",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
@@ -92,6 +92,8 @@ def _prepare_extras(requirements_dir: str = _PATH_REQUIRES, skip_files: tuple =
# Specify the Python versions you support here. In particular, ensure
# that you indicate whether you support Python 2, Python 3 or both.
"Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
diff --git a/src/litdata/__about__.py b/src/litdata/__about__.py
index 9aed6d3d..82436a84 100644
--- a/src/litdata/__about__.py
+++ b/src/litdata/__about__.py
@@ -14,7 +14,7 @@
import time
-__version__ = "0.2.28"
+__version__ = "0.2.34"
__author__ = "Lightning AI et al."
__author_email__ = "pytorch@lightning.ai"
__license__ = "Apache-2.0"
diff --git a/src/litdata/constants.py b/src/litdata/constants.py
index a6a714c7..f6e02fad 100644
--- a/src/litdata/constants.py
+++ b/src/litdata/constants.py
@@ -23,6 +23,7 @@
_DEFAULT_CHUNK_BYTES = 1 << 26 # 64M B
_DEFAULT_FAST_DEV_RUN_ITEMS = 10
_DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks")
+_DEFAULT_LIGHTNING_CACHE_DIR = os.path.join("/cache", "chunks")
# This is required for full pytree serialization / deserialization support
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
diff --git a/src/litdata/helpers.py b/src/litdata/helpers.py
new file mode 100644
index 00000000..5a343415
--- /dev/null
+++ b/src/litdata/helpers.py
@@ -0,0 +1,61 @@
+import functools
+import warnings
+from typing import Any, Optional
+
+import requests
+from packaging import version as packaging_version
+
+
+class WarningCache(set):
+ """Cache for warnings."""
+
+ def warn(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
+ """Trigger warning message."""
+ if message not in self:
+ self.add(message)
+ warnings.warn(message, stacklevel=stacklevel, **kwargs)
+
+
+warning_cache = WarningCache()
+
+__package_name__ = "litdata"
+
+
+@functools.lru_cache(maxsize=1)
+def _get_newer_version(curr_version: str) -> Optional[str]:
+ """Check PyPI for newer versions of ``litdata``.
+
+    Returns the newest version if it differs from the current one, or ``None`` otherwise.
+
+ """
+ if packaging_version.parse(curr_version).is_prerelease:
+ return None
+ try:
+ response = requests.get(f"https://pypi.org/pypi/{__package_name__}/json", timeout=30)
+ response_json = response.json()
+ releases = response_json["releases"]
+ if curr_version not in releases:
+ # Always return None if not installed from PyPI (e.g. dev versions)
+ return None
+ latest_version = response_json["info"]["version"]
+ parsed_version = packaging_version.parse(latest_version)
+ is_invalid = response_json["info"]["yanked"] or parsed_version.is_devrelease or parsed_version.is_prerelease
+ return None if curr_version == latest_version or is_invalid else latest_version
+ except requests.exceptions.RequestException:
+ return None
+
+
+def _check_version_and_prompt_upgrade(curr_version: str) -> None:
+ """Checks that the current version of ``litdata`` is the latest on PyPI.
+
+ If not, warn the user to upgrade ``litdata``.
+
+ """
+ new_version = _get_newer_version(curr_version)
+ if new_version:
+ warning_cache.warn(
+ f"A newer version of {__package_name__} is available ({new_version}). "
+ f"Please consider upgrading with `pip install -U {__package_name__}`. "
+ "Not all functionalities of the platform can be guaranteed to work with the current version.",
+ )
+ return
diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py
index f1af9afa..beb6b3d4 100644
--- a/src/litdata/processing/data_processor.py
+++ b/src/litdata/processing/data_processor.py
@@ -456,6 +456,7 @@ def run(self) -> None:
try:
self._setup()
self._loop()
+ self._terminate()
except Exception:
traceback_format = traceback.format_exc()
self.error_queue.put(traceback_format)
@@ -469,6 +470,19 @@ def _setup(self) -> None:
self._start_uploaders()
self._start_remover()
+ def _terminate(self) -> None:
+ """Make sure all the uploaders, downloaders and removers are terminated."""
+ for uploader in self.uploaders:
+ if uploader.is_alive():
+ uploader.join()
+
+ for downloader in self.downloaders:
+ if downloader.is_alive():
+ downloader.join()
+
+ if self.remover and self.remover.is_alive():
+ self.remover.join()
+
def _loop(self) -> None:
num_downloader_finished = 0
@@ -795,7 +809,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul
chunks = [file for file in os.listdir(cache_dir) if file.endswith(".bin")]
if chunks and delete_cached_files and output_dir.path is not None:
- raise RuntimeError(f"All the chunks should have been deleted. Found {chunks}")
+ raise RuntimeError(f"All the chunks should have been deleted. Found {chunks} in cache: {cache_dir}")
merge_cache = Cache(cache_dir, chunk_bytes=1)
node_rank = _get_node_rank()
@@ -1110,6 +1124,10 @@ def run(self, data_recipe: DataRecipe) -> None:
current_total = new_total
if current_total == num_items:
+ # make sure all processes are terminated
+ for w in self.workers:
+ if w.is_alive():
+ w.join()
break
if _IS_IN_STUDIO and node_rank == 0 and _ENABLE_STATUS:
@@ -1117,18 +1135,17 @@ def run(self, data_recipe: DataRecipe) -> None:
json.dump({"progress": str(100 * current_total * num_nodes / total_num_items) + "%"}, f)
# Exit early if all the workers are done.
- # This means there were some kinda of errors.
+            # This means either some errors occurred, or the optimize function had very little work and finished quickly.
if all(not w.is_alive() for w in self.workers):
- raise RuntimeError("One of the worker has failed")
+ try:
+ error = self.error_queue.get(timeout=0.01)
+ self._exit_on_error(error)
+ except Empty:
+ continue
if _TQDM_AVAILABLE:
pbar.close()
- # TODO: Check whether this is still required.
- if num_nodes == 1:
- for w in self.workers:
- w.join()
-
print("Workers are finished.")
result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)
diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py
index dd62909a..f231ddd9 100644
--- a/src/litdata/processing/functions.py
+++ b/src/litdata/processing/functions.py
@@ -27,7 +27,9 @@
import torch
+from litdata import __version__
from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO
+from litdata.helpers import _check_version_and_prompt_upgrade
from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from litdata.processing.readers import BaseReader
from litdata.processing.utilities import (
@@ -222,6 +224,8 @@ def map(
batch_size: Group the inputs into batches of batch_size length.
"""
+ _check_version_and_prompt_upgrade(__version__)
+
if isinstance(inputs, StreamingDataLoader) and batch_size is not None:
raise ValueError("When providing a streaming dataloader, pass the batch_size to the dataloader directly.")
@@ -351,6 +355,8 @@ def optimize(
inside an interactive shell like Ipython.
"""
+ _check_version_and_prompt_upgrade(__version__)
+
if mode is not None and mode not in ["append", "overwrite"]:
raise ValueError(f"The provided `mode` should be either `append` or `overwrite`. Found {mode}.")
@@ -521,12 +527,13 @@ class CopyInfo:
new_filename: str
-def merge_datasets(input_dirs: List[str], output_dir: str) -> None:
+def merge_datasets(input_dirs: List[str], output_dir: str, max_workers: Optional[int] = os.cpu_count()) -> None:
"""Enables to merge multiple existing optimized datasets into a single optimized dataset.
Args:
input_dirs: A list of directories pointing to the existing optimized datasets.
output_dir: The directory where the merged dataset would be stored.
+        max_workers: Maximum number of threads used to copy the files.
"""
if len(input_dirs) == 0:
@@ -537,6 +544,7 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None:
resolved_input_dirs = [_resolve_dir(input_dir) for input_dir in input_dirs]
resolved_output_dir = _resolve_dir(output_dir)
+ max_workers = max_workers or 1
if any(input_dir == resolved_output_dir for input_dir in resolved_input_dirs):
raise ValueError("The provided output_dir was found within the input_dirs. This isn't supported.")
@@ -580,8 +588,11 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None:
_tqdm = _get_tqdm_iterator_if_available()
- for copy_info in _tqdm(copy_infos):
- _apply_copy(copy_info, resolved_output_dir)
+ with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+ futures: List[concurrent.futures.Future] = []
+ for copy_info in _tqdm(copy_infos):
+ future = executor.submit(_apply_copy, copy_info, resolved_output_dir)
+ futures.append(future)
_save_index(index_json, resolved_output_dir)
diff --git a/src/litdata/streaming/client.py b/src/litdata/streaming/client.py
index d24803c3..68c14911 100644
--- a/src/litdata/streaming/client.py
+++ b/src/litdata/streaming/client.py
@@ -37,7 +37,7 @@ def _create_client(self) -> None:
os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials"
)
- if has_shared_credentials_file or not _IS_IN_STUDIO:
+ if has_shared_credentials_file or not _IS_IN_STUDIO or self._storage_options:
self._client = boto3.client(
"s3",
**{
diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py
index df0ea012..7c43d435 100644
--- a/src/litdata/streaming/config.py
+++ b/src/litdata/streaming/config.py
@@ -226,7 +226,12 @@ def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
begin = self._intervals[index.chunk_index][0]
- return local_chunkpath, begin, chunk["chunk_bytes"]
+ filesize_bytes = chunk["chunk_bytes"]
+
+ if self._config and self._config.get("encryption") is None:
+ filesize_bytes += (1 + chunk["chunk_size"]) * 4
+
+ return local_chunkpath, begin, filesize_bytes
def _get_chunk_index_from_filename(self, chunk_filename: str) -> int:
"""Retrieves the associated chunk_index for a given chunk filename."""
diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index ea82ce7a..5b13d44f 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -19,9 +19,11 @@
import numpy as np
from torch.utils.data import IterableDataset
+from litdata import __version__
from litdata.constants import (
_INDEX_FILENAME,
)
+from litdata.helpers import _check_version_and_prompt_upgrade
from litdata.streaming import Cache
from litdata.streaming.downloader import get_downloader_cls # noqa: F401
from litdata.streaming.item_loader import BaseItemLoader
@@ -46,6 +48,7 @@ class StreamingDataset(IterableDataset):
def __init__(
self,
input_dir: Union[str, "Dir"],
+ cache_dir: Optional[Union[str, "Dir"]] = None,
item_loader: Optional[BaseItemLoader] = None,
shuffle: bool = False,
drop_last: Optional[bool] = None,
@@ -61,6 +64,8 @@ def __init__(
Args:
input_dir: Path to the folder where the input data is stored.
+ cache_dir: Path to the folder where the cache data is stored. If not provided, the cache will be stored
+ in the default cache directory.
item_loader: The logic to load an item from a chunk.
shuffle: Whether to shuffle the data.
drop_last: If `True`, drops the last items to ensure that
@@ -76,6 +81,8 @@ def __init__(
max_pre_download: Maximum number of chunks that can be pre-downloaded by the StreamingDataset.
"""
+ _check_version_and_prompt_upgrade(__version__)
+
super().__init__()
if not isinstance(shuffle, bool):
raise ValueError(f"Shuffle should be a boolean. Found {shuffle}")
@@ -84,12 +91,14 @@ def __init__(
raise ValueError("subsample must be a float with value between 0 and 1.")
input_dir = _resolve_dir(input_dir)
+ cache_dir = _resolve_dir(cache_dir)
self.input_dir = input_dir
+ self.cache_dir = cache_dir
self.subsampled_files: List[str] = []
self.region_of_interest: List[Tuple[int, int]] = []
self.subsampled_files, self.region_of_interest = subsample_streaming_dataset(
- self.input_dir, item_loader, subsample, shuffle, seed, storage_options
+ self.input_dir, self.cache_dir, item_loader, subsample, shuffle, seed, storage_options
)
self.item_loader = item_loader
@@ -155,7 +164,8 @@ def set_epoch(self, current_epoch: int) -> None:
def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
if _should_replace_path(self.input_dir.path):
cache_path = _try_create_cache_dir(
- input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url
+ input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url,
+ cache_dir=self.cache_dir.path,
)
if cache_path is not None:
self.input_dir.path = cache_path
@@ -399,6 +409,7 @@ def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int
"current_epoch": self.current_epoch,
"input_dir_path": self.input_dir.path,
"input_dir_url": self.input_dir.url,
+ "cache_dir_path": self.cache_dir.path,
"item_loader": self.item_loader.state_dict() if self.item_loader else None,
"drop_last": self.drop_last,
"seed": self.seed,
@@ -438,7 +449,8 @@ def _validate_state_dict(self) -> None:
# In this case, validate the cache folder is the same.
if _should_replace_path(state["input_dir_path"]):
cache_path = _try_create_cache_dir(
- input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"]
+ input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"],
+ cache_dir=state.get("cache_dir_path"),
)
if cache_path != self.input_dir.path:
raise ValueError(
diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py
index 41e4a6a9..5f071734 100644
--- a/src/litdata/streaming/downloader.py
+++ b/src/litdata/streaming/downloader.py
@@ -68,10 +68,15 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0
):
if self._s5cmd_available:
+ env = None
+ if self._storage_options:
+ env = os.environ.copy()
+ env.update(self._storage_options)
proc = subprocess.Popen(
f"s5cmd cp {remote_filepath} {local_filepath}",
shell=True,
stdout=subprocess.PIPE,
+ env=env,
)
proc.wait()
else:
@@ -79,8 +84,6 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
extra_args: Dict[str, Any] = {}
- # try:
- # with FileLock(local_filepath + ".lock", timeout=1):
if not os.path.exists(local_filepath):
# Issue: https://github.com/boto/boto3/issues/3113
self._client.client.download_file(
diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py
index acc50d2b..9c58cabf 100644
--- a/src/litdata/streaming/item_loader.py
+++ b/src/litdata/streaming/item_loader.py
@@ -23,9 +23,7 @@
import numpy as np
import torch
-from litdata.constants import (
- _TORCH_DTYPES_MAPPING,
-)
+from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING
from litdata.streaming.serializers import Serializer
from litdata.utilities._pytree import PyTree, tree_unflatten
from litdata.utilities.encryption import Encryption, EncryptionLevel
@@ -88,7 +86,7 @@ def load_item_from_chunk(
chunk_index: int,
chunk_filepath: str,
begin: int,
- chunk_bytes: int,
+ filesize_bytes: int,
) -> Any:
"""Returns an item loaded from a chunk."""
@@ -132,7 +130,7 @@ def load_item_from_chunk(
chunk_index: int,
chunk_filepath: str,
begin: int,
- chunk_bytes: int,
+ filesize_bytes: int,
encryption: Optional[Encryption] = None,
) -> bytes:
offset = (1 + (index - begin) if index >= begin else index + 1) * 4
@@ -141,11 +139,11 @@ def load_item_from_chunk(
del self._chunk_filepaths[chunk_filepath]
if chunk_filepath not in self._chunk_filepaths:
- exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= chunk_bytes
+ exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= filesize_bytes
while not exists:
sleep(0.1)
- exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= chunk_bytes
+ exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= filesize_bytes
self._chunk_filepaths[chunk_filepath] = True
@@ -281,7 +279,17 @@ def setup(
region_of_interest: Optional[List[Tuple[int, int]]] = None,
) -> None:
super().setup(config, chunks, serializers, region_of_interest)
- self._dtype = _TORCH_DTYPES_MAPPING[int(config["data_format"][0].split(":")[1])]
+
+ serializer_name, dtype_index = self._data_format[0].split(":")
+ if serializer_name not in ["no_header_numpy", "no_header_tensor"]:
+ raise ValueError("The provided data format isn't supported.")
+
+ self._serializer_name = serializer_name
+ self._dtype = (
+ _TORCH_DTYPES_MAPPING[int(dtype_index)] # type: ignore
+ if serializer_name == "no_header_tensor"
+ else _NUMPY_DTYPES_MAPPING[int(dtype_index)]
+ )
if all(chunk["dim"] is None for chunk in self._chunks):
raise ValueError("The provided chunks isn't properly setup.")
@@ -329,7 +337,7 @@ def load_item_from_chunk(
chunk_index: int,
chunk_filepath: str,
begin: int,
- chunk_bytes: int,
+ filesize_bytes: int,
) -> torch.Tensor:
assert self._block_size
@@ -337,11 +345,11 @@ def load_item_from_chunk(
del self._chunk_filepaths[chunk_filepath]
if chunk_filepath not in self._chunk_filepaths:
- exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
+ exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > filesize_bytes
while not exists:
sleep(0.1)
- exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
+ exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > filesize_bytes
self._chunk_filepaths[chunk_filepath] = True
@@ -350,7 +358,12 @@ def load_item_from_chunk(
buffer: bytes = self._buffers[chunk_index]
offset = self._dtype.itemsize * (index - begin) * self._block_size
- return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
+
+ if self._serializer_name == "no_header_tensor":
+ data = torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
+ else:
+ data = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) # type: ignore
+ return data
def delete(self, chunk_index: int, chunk_filepath: str) -> None:
if os.path.exists(chunk_filepath):
@@ -360,6 +373,14 @@ def delete(self, chunk_index: int, chunk_filepath: str) -> None:
del self._mmaps[chunk_index]
os.remove(chunk_filepath)
+ def close(self, chunk_index: int) -> None:
+ """Release the memory-mapped file for a specific chunk index."""
+ if chunk_index in self._mmaps:
+ self._mmaps[chunk_index]._mmap.close()
+ del self._mmaps[chunk_index]
+ if chunk_index in self._buffers:
+ del self._buffers[chunk_index]
+
@classmethod
def encode_data(cls, data: List[bytes], _: List[int], flattened: List[Any]) -> Tuple[bytes, Optional[int]]:
return data[0], flattened[0].shape[0]
diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py
index 60cd2965..078c9fa2 100644
--- a/src/litdata/streaming/reader.py
+++ b/src/litdata/streaming/reader.py
@@ -20,7 +20,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from litdata.streaming.config import ChunksConfig, Interval
-from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader
+from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader
from litdata.streaming.sampler import ChunkedIndex
from litdata.streaming.serializers import Serializer, _get_serializers
from litdata.utilities.encryption import Encryption
@@ -278,17 +278,16 @@ def read(self, index: ChunkedIndex) -> Any:
self._last_chunk_index = index.chunk_index
# Fetch the element
- chunk_filepath, begin, chunk_bytes = self.config[index]
+ chunk_filepath, begin, filesize_bytes = self.config[index]
if isinstance(self._item_loader, PyTreeLoader):
item = self._item_loader.load_item_from_chunk(
- index.index, index.chunk_index, chunk_filepath, begin, chunk_bytes, self._encryption
+ index.index, index.chunk_index, chunk_filepath, begin, filesize_bytes, self._encryption
)
else:
item = self._item_loader.load_item_from_chunk(
- index.index, index.chunk_index, chunk_filepath, begin, chunk_bytes
+ index.index, index.chunk_index, chunk_filepath, begin, filesize_bytes
)
-
# We need to request deletion after the latest element has been loaded.
# Otherwise, this could trigger segmentation fault error depending on the item loader used.
if (
@@ -302,6 +301,11 @@ def read(self, index: ChunkedIndex) -> Any:
# inform the chunk has been completely consumed
self._prepare_thread.delete([self._last_chunk_index])
+ if index.chunk_index != self._last_chunk_index:
+ # Close the memory-mapped file for the last chunk index
+ if isinstance(self._item_loader, TokensLoader) and self._last_chunk_index is not None:
+ self._item_loader.close(self._last_chunk_index)
+
# track the new chunk index as the latest one
self._last_chunk_index = index.chunk_index
diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py
index 98ce5fef..528c9bfd 100644
--- a/src/litdata/streaming/resolver.py
+++ b/src/litdata/streaming/resolver.py
@@ -346,6 +346,7 @@ def _execute(
num_nodes: int,
machine: Optional["Machine"] = None,
command: Optional[str] = None,
+ interruptible: bool = False,
) -> None:
"""Remotely execute the current operator."""
if not _LIGHTNING_SDK_AVAILABLE:
@@ -370,6 +371,7 @@ def _execute(
teamspace_id=studio._teamspace.id,
cluster_id=studio._studio.cluster_id,
machine=machine or studio._studio_api.get_machine(studio._studio.id, studio._teamspace.id),
+ interruptible=interruptible,
)
has_printed = False
diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py
index f453f538..77d5c2bf 100644
--- a/src/litdata/streaming/serializers.py
+++ b/src/litdata/streaming/serializers.py
@@ -22,6 +22,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import numpy as np
+import tifffile
import torch
from lightning_utilities.core.imports import RequirementCache
@@ -201,7 +202,7 @@ def deserialize(self, data: bytes) -> torch.Tensor:
return torch.reshape(tensor, shape)
def can_serialize(self, item: torch.Tensor) -> bool:
- return isinstance(item, torch.Tensor) and type(item) == torch.Tensor and len(item.shape) > 1
+ return isinstance(item, torch.Tensor) and type(item) == torch.Tensor and len(item.shape) != 1
class NoHeaderTensorSerializer(Serializer):
@@ -387,13 +388,33 @@ def can_serialize(self, data: float) -> bool:
return isinstance(data, float)
+class TIFFSerializer(Serializer):
+ """Serializer for TIFF files using tifffile."""
+
+ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]:
+ if not isinstance(item, str) or not os.path.isfile(item):
+ raise ValueError(f"The item to serialize must be a valid file path. Received: {item}")
+
+ # Read the TIFF file as bytes
+ with open(item, "rb") as f:
+ data = f.read()
+
+ return data, None
+
+ def deserialize(self, data: bytes) -> Any:
+ return tifffile.imread(io.BytesIO(data)) # This is a NumPy array
+
+ def can_serialize(self, item: Any) -> bool:
+ return isinstance(item, str) and os.path.isfile(item) and item.lower().endswith((".tif", ".tiff"))
+
+
_SERIALIZERS = OrderedDict(
**{
"str": StringSerializer(),
"int": IntegerSerializer(),
"float": FloatSerializer(),
"video": VideoSerializer(),
- "tif": FileSerializer(),
+ "tifffile": TIFFSerializer(),
"file": FileSerializer(),
"pil": PILSerializer(),
"jpeg": JPEGSerializer(),
diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py
index a23d9e3f..55d72260 100644
--- a/src/litdata/utilities/dataset_utilities.py
+++ b/src/litdata/utilities/dataset_utilities.py
@@ -8,7 +8,7 @@
import numpy as np
-from litdata.constants import _DEFAULT_CACHE_DIR, _INDEX_FILENAME
+from litdata.constants import _DEFAULT_CACHE_DIR, _DEFAULT_LIGHTNING_CACHE_DIR, _INDEX_FILENAME
from litdata.streaming.downloader import get_downloader_cls
from litdata.streaming.item_loader import BaseItemLoader, TokensLoader
from litdata.streaming.resolver import Dir, _resolve_dir
@@ -17,6 +17,7 @@
def subsample_streaming_dataset(
input_dir: Dir,
+ cache_dir: Optional[Dir] = None,
item_loader: Optional[BaseItemLoader] = None,
subsample: float = 1.0,
shuffle: bool = False,
@@ -39,7 +40,9 @@ def subsample_streaming_dataset(
# Make sure input_dir contains cache path and remote url
if _should_replace_path(input_dir.path):
cache_path = _try_create_cache_dir(
- input_dir=input_dir.path if input_dir.path else input_dir.url, storage_options=storage_options
+ input_dir=input_dir.path if input_dir.path else input_dir.url,
+ cache_dir=cache_dir.path if cache_dir else None,
+ storage_options=storage_options,
)
if cache_path is not None:
input_dir.path = cache_path
@@ -137,7 +140,11 @@ def _clear_cache_dir_if_updated(input_dir_hash_filepath: str, updated_at_hash: s
shutil.rmtree(input_dir_hash_filepath)
-def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Dict] = {}) -> Optional[str]:
+def _try_create_cache_dir(
+ input_dir: Optional[str],
+ cache_dir: Optional[str] = None,
+ storage_options: Optional[Dict] = {},
+) -> Optional[str]:
resolved_input_dir = _resolve_dir(input_dir)
updated_at = _read_updated_at(resolved_input_dir, storage_options)
@@ -147,13 +154,13 @@ def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Di
dir_url_hash = hashlib.md5((resolved_input_dir.url or "").encode()).hexdigest() # noqa: S324
if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ:
- input_dir_hash_filepath = os.path.join(_DEFAULT_CACHE_DIR, dir_url_hash)
+ input_dir_hash_filepath = os.path.join(cache_dir or _DEFAULT_CACHE_DIR, dir_url_hash)
_clear_cache_dir_if_updated(input_dir_hash_filepath, updated_at)
cache_dir = os.path.join(input_dir_hash_filepath, updated_at)
os.makedirs(cache_dir, exist_ok=True)
return cache_dir
- input_dir_hash_filepath = os.path.join("/cache", "chunks", dir_url_hash)
+ input_dir_hash_filepath = os.path.join(cache_dir or _DEFAULT_LIGHTNING_CACHE_DIR, dir_url_hash)
_clear_cache_dir_if_updated(input_dir_hash_filepath, updated_at)
cache_dir = os.path.join(input_dir_hash_filepath, updated_at)
os.makedirs(cache_dir, exist_ok=True)
diff --git a/tests/processing/test_functions.py b/tests/processing/test_functions.py
index 80eec0ba..42482b7e 100644
--- a/tests/processing/test_functions.py
+++ b/tests/processing/test_functions.py
@@ -1,10 +1,15 @@
+import glob
import os
+import random
+import shutil
import sys
+from pathlib import Path
from unittest import mock
import cryptography
import numpy as np
import pytest
+import requests
from litdata import StreamingDataset, merge_datasets, optimize, walk
from litdata.processing.functions import _get_input_dir, _resolve_dir
from litdata.streaming.cache import Cache
@@ -475,3 +480,53 @@ def test_optimize_with_rsa_encryption(tmpdir):
# encryption=rsa,
# mode="overwrite",
# )
+
+
+def tokenize(filename: str):
+ with open(filename, encoding="utf-8") as file:
+ text = file.read()
+ text = text.strip().split(" ")
+ word_to_int = {word: random.randint(1, 1000) for word in set(text)} # noqa: S311
+ tokenized = [word_to_int[word] for word in text]
+
+ yield tokenized
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows")
+def test_optimize_race_condition(tmpdir):
+ # issue: https://github.com/Lightning-AI/litdata/issues/367
+ # run_commands = [
+ # "mkdir -p tempdir/custom_texts",
+ # "curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output tempdir/custom_texts/book1.txt",
+ # "curl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output tempdir/custom_texts/book2.txt",
+ # ]
+ shutil.rmtree(f"{tmpdir}/custom_texts", ignore_errors=True)
+ os.makedirs(f"{tmpdir}/custom_texts", exist_ok=True)
+
+ urls = [
+ "https://www.gutenberg.org/cache/epub/24440/pg24440.txt",
+ "https://www.gutenberg.org/cache/epub/26393/pg26393.txt",
+ ]
+
+ for i, url in enumerate(urls):
+ print(f"downloading {i+1} file")
+ with requests.get(url, stream=True, timeout=10) as r:
+ r.raise_for_status() # Raise an exception for bad status codes
+
+ with open(f"{tmpdir}/custom_texts/book{i+1}.txt", "wb") as f:
+ for chunk in r.iter_content(chunk_size=8192):
+ f.write(chunk)
+
+ print("=" * 100)
+
+ train_files = sorted(glob.glob(str(Path(f"{tmpdir}/custom_texts") / "*.txt")))
+ print("=" * 100)
+ print(train_files)
+ print("=" * 100)
+ optimize(
+ fn=tokenize,
+ inputs=train_files,
+ output_dir=f"{tmpdir}/temp",
+ num_workers=1,
+ chunk_bytes="50MB",
+ )
diff --git a/tests/streaming/test_combined.py b/tests/streaming/test_combined.py
index d9eb34a2..4440b6cc 100644
--- a/tests/streaming/test_combined.py
+++ b/tests/streaming/test_combined.py
@@ -395,6 +395,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -410,6 +411,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -432,6 +434,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -447,6 +450,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -469,6 +473,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -484,6 +489,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -506,6 +512,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -521,6 +528,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -543,6 +551,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -558,6 +567,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -580,6 +590,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -595,6 +606,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -617,6 +629,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -632,6 +645,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -657,6 +671,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -672,6 +687,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -694,6 +710,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -709,6 +726,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -731,6 +749,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -746,6 +765,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -768,6 +788,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -783,6 +804,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -805,6 +827,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -820,6 +843,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -842,6 +866,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -857,6 +882,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -879,6 +905,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
@@ -894,6 +921,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
+ "cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py
index ef93021c..3b001099 100644
--- a/tests/streaming/test_dataset.py
+++ b/tests/streaming/test_dataset.py
@@ -507,6 +507,28 @@ def test_dataset_for_text_tokens(tmpdir):
break
+@pytest.mark.skipif(sys.platform == "win32", reason="windows isn't supported")
+def test_dataset_for_text_tokens_with_large_num_chunks(tmpdir):
+ import resource
+
+ resource.setrlimit(resource.RLIMIT_NOFILE, (1024, 1024))
+
+ block_size = 1024
+ cache = Cache(input_dir=str(tmpdir), chunk_bytes="10KB", item_loader=TokensLoader(block_size))
+
+ for i in range(10000):
+ text_ids = torch.randint(0, 10001, (torch.randint(100, 1001, (1,)).item(),)).numpy()
+ cache._add_item(i, text_ids)
+
+ cache.done()
+ cache.merge()
+
+ dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=True)
+
+ for _ in dataset:
+ pass
+
+
def test_dataset_with_1d_array(tmpdir):
seed_everything(42)
diff --git a/tests/streaming/test_downloader.py b/tests/streaming/test_downloader.py
index 7c79afe5..cf2cd34c 100644
--- a/tests/streaming/test_downloader.py
+++ b/tests/streaming/test_downloader.py
@@ -1,6 +1,7 @@
+# ruff: noqa: S604
import os
from unittest import mock
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
from litdata.streaming.downloader import (
AzureDownloader,
@@ -21,6 +22,62 @@ def test_s3_downloader_fast(tmpdir, monkeypatch):
popen_mock.wait.assert_called()
+@patch("os.system")
+@patch("subprocess.Popen")
+def test_s3_downloader_with_s5cmd_no_storage_options(popen_mock, system_mock, tmpdir):
+ system_mock.return_value = 0 # Simulate s5cmd being available
+ process_mock = MagicMock()
+ popen_mock.return_value = process_mock
+
+ # Initialize the S3Downloader without storage options
+ downloader = S3Downloader("s3://random_bucket", str(tmpdir), [])
+
+ # Action: Call the download_file method
+ remote_filepath = "s3://random_bucket/sample_file.txt"
+ local_filepath = os.path.join(tmpdir, "sample_file.txt")
+ downloader.download_file(remote_filepath, local_filepath)
+
+ # Assertion: Verify subprocess.Popen was called with the correct arguments and no environment variables
+ popen_mock.assert_called_once_with(
+ f"s5cmd cp {remote_filepath} {local_filepath}",
+ shell=True,
+ stdout=subprocess.PIPE,
+ env=None,
+ )
+ process_mock.wait.assert_called_once()
+
+
+@patch("os.system")
+@patch("subprocess.Popen")
+def test_s3_downloader_with_s5cmd_with_storage_options(popen_mock, system_mock, tmpdir):
+ system_mock.return_value = 0 # Simulate s5cmd being available
+ process_mock = MagicMock()
+ popen_mock.return_value = process_mock
+
+ storage_options = {"AWS_ACCESS_KEY_ID": "dummy_key", "AWS_SECRET_ACCESS_KEY": "dummy_secret"}
+
+ # Initialize the S3Downloader with storage options
+ downloader = S3Downloader("s3://random_bucket", str(tmpdir), [], storage_options)
+
+ # Action: Call the download_file method
+ remote_filepath = "s3://random_bucket/sample_file.txt"
+ local_filepath = os.path.join(tmpdir, "sample_file.txt")
+ downloader.download_file(remote_filepath, local_filepath)
+
+ # Create expected environment variables by merging the current env with storage_options
+ expected_env = os.environ.copy()
+ expected_env.update(storage_options)
+
+ # Assertion: Verify subprocess.Popen was called with the correct arguments and environment variables
+ popen_mock.assert_called_once_with(
+ f"s5cmd cp {remote_filepath} {local_filepath}",
+ shell=True,
+ stdout=subprocess.PIPE,
+ env=expected_env,
+ )
+ process_mock.wait.assert_called_once()
+
+
@mock.patch("litdata.streaming.downloader._GOOGLE_STORAGE_AVAILABLE", True)
def test_gcp_downloader(tmpdir, monkeypatch, google_mock):
# Create mock objects
diff --git a/tests/streaming/test_item_loader.py b/tests/streaming/test_item_loader.py
index ecb8e6f8..a5828b24 100644
--- a/tests/streaming/test_item_loader.py
+++ b/tests/streaming/test_item_loader.py
@@ -1,10 +1,11 @@
from unittest.mock import MagicMock
+import numpy as np
import torch
-from litdata.constants import _TORCH_DTYPES_MAPPING
+from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING
from litdata.streaming import Cache
from litdata.streaming.dataset import StreamingDataset
-from litdata.streaming.item_loader import PyTreeLoader
+from litdata.streaming.item_loader import PyTreeLoader, TokensLoader
def test_serializer_setup():
@@ -38,3 +39,30 @@ def test_pytreeloader_with_no_header_tensor_serializer(tmpdir):
item = dataset[i]
assert torch.allclose(i * torch.ones(10).to(_TORCH_DTYPES_MAPPING[dtype_index_float]), item["float"])
assert torch.allclose(i * torch.ones(10).to(_TORCH_DTYPES_MAPPING[dtype_index_long]), item["long"])
+
+
+def test_tokensloader_with_no_header_numpy_serializer(tmpdir):
+ cache = Cache(str(tmpdir), chunk_size=512, item_loader=TokensLoader())
+ assert isinstance(cache._reader._item_loader, TokensLoader)
+
+ dtype_index_int32 = 3
+ dtype = _NUMPY_DTYPES_MAPPING[dtype_index_int32]
+
+ for i in range(10):
+ data = np.random.randint(0, 100, size=(256), dtype=dtype)
+ cache._add_item(i, data)
+
+ data_format = [f"no_header_numpy:{dtype_index_int32}"]
+ assert cache._writer.get_config()["data_format"] == data_format
+ cache.done()
+ cache.merge()
+
+ dataset = StreamingDataset(
+ input_dir=str(tmpdir),
+ drop_last=True,
+ item_loader=TokensLoader(block_size=256),
+ )
+
+ for data in dataset:
+ assert data.shape == (256,)
+ assert data.dtype == dtype
diff --git a/tests/streaming/test_resolver.py b/tests/streaming/test_resolver.py
index 90729ffb..8c774c23 100644
--- a/tests/streaming/test_resolver.py
+++ b/tests/streaming/test_resolver.py
@@ -291,6 +291,7 @@ def print_fn(msg, file=None):
"teamspace_id": "teamspace_id",
"cluster_id": "cluster_id",
"machine": "cpu",
+ "interruptible": False,
}
generated_kwargs = (
diff --git a/tests/streaming/test_serializer.py b/tests/streaming/test_serializer.py
index 70d10a41..6228e52f 100644
--- a/tests/streaming/test_serializer.py
+++ b/tests/streaming/test_serializer.py
@@ -15,10 +15,12 @@
import os
import random
import sys
+import tempfile
from unittest import mock
import numpy as np
import pytest
+import tifffile
import torch
from lightning_utilities.core.imports import RequirementCache
from litdata.streaming.serializers import (
@@ -34,6 +36,7 @@
NumpySerializer,
PILSerializer,
TensorSerializer,
+ TIFFSerializer,
VideoSerializer,
_get_serializers,
)
@@ -46,6 +49,7 @@ def seed_everything(random_seed):
_PIL_AVAILABLE = RequirementCache("PIL")
+_TIFFFILE_AVAILABLE = RequirementCache("tifffile")
def test_serializers():
@@ -55,7 +59,7 @@ def test_serializers():
"int",
"float",
"video",
- "tif",
+ "tifffile",
"file",
"pil",
"jpeg",
@@ -252,6 +256,14 @@ def test_deserialize_empty_tensor():
assert torch.equal(t, new_t)
+def test_deserialize_scalar_tensor():
+ serializer = TensorSerializer()
+ t = torch.tensor(0)
+ data, _ = serializer.serialize(t)
+ new_t = serializer.deserialize(data)
+ assert torch.equal(t, new_t)
+
+
def test_deserialize_empty_no_header_tensor():
serializer = NoHeaderTensorSerializer()
t = torch.ones((0,)).int()
@@ -265,3 +277,44 @@ def test_deserialize_empty_no_header_tensor():
serializer.setup(name)
new_t = serializer.deserialize(data)
assert torch.equal(t, new_t)
+
+
+def test_can_serialize_tensor():
+ serializer = TensorSerializer()
+ # Check that the TensorSerializer can serialize scalar (0-D) tensors as well as tensors with more than one dimension
+ assert serializer.can_serialize(torch.tensor(0))
+ assert serializer.can_serialize(torch.tensor([[0, 0]]))
+ # Check that it does not serialize one-dimensional tensors; those are handled by the dedicated NoHeaderTensorSerializer
+ assert not serializer.can_serialize(torch.tensor([0, 0]))
+
+
+@pytest.mark.skipif(not _TIFFFILE_AVAILABLE, reason="Requires: ['tifffile']")
+def test_tiff_serializer():
+ serializer = TIFFSerializer()
+
+ # Create a synthetic multispectral image
+ height, width, bands = 28, 28, 12
+ np_data = np.random.randint(0, 65535, size=(height, width, bands), dtype=np.uint16)
+
+ with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp_file:
+ tifffile.imwrite(tmp_file.name, np_data)
+ file_path = tmp_file.name
+
+ # Test can_serialize
+ assert serializer.can_serialize(file_path)
+
+ # Serialize
+ data, _ = serializer.serialize(file_path)
+ assert isinstance(data, bytes)
+
+ # Deserialize
+ deserialized_data = serializer.deserialize(data)
+ assert isinstance(deserialized_data, np.ndarray)
+ assert deserialized_data.shape == (height, width, bands)
+ assert deserialized_data.dtype == np.uint16
+
+ # Validate data content
+ assert np.array_equal(np_data, deserialized_data)
+
+ # Clean up
+ os.remove(file_path)
diff --git a/tests/utilities/test_dataset_utilities.py b/tests/utilities/test_dataset_utilities.py
index fc08e6df..03d8d905 100644
--- a/tests/utilities/test_dataset_utilities.py
+++ b/tests/utilities/test_dataset_utilities.py
@@ -42,6 +42,26 @@ def test_try_create_cache_dir():
assert len(makedirs_mock.mock_calls) == 2
+def test_try_create_cache_dir_with_custom_cache_dir(tmpdir):
+ cache_dir = str(tmpdir.join("cache"))
+ with mock.patch.dict(os.environ, {}, clear=True):
+ assert os.path.join(
+ cache_dir, "d41d8cd98f00b204e9800998ecf8427e", "100b8cad7cf2a56f6df78f171f97a1ec"
+ ) in _try_create_cache_dir("any", cache_dir)
+
+ with (
+ mock.patch.dict("os.environ", {"LIGHTNING_CLUSTER_ID": "abc", "LIGHTNING_CLOUD_PROJECT_ID": "123"}),
+ mock.patch("litdata.streaming.dataset.os.makedirs") as makedirs_mock,
+ ):
+ cache_dir_1 = _try_create_cache_dir("", cache_dir)
+ cache_dir_2 = _try_create_cache_dir("ssdf", cache_dir)
+ assert cache_dir_1 != cache_dir_2
+ assert cache_dir_1 == os.path.join(
+ cache_dir, "d41d8cd98f00b204e9800998ecf8427e", "d41d8cd98f00b204e9800998ecf8427e"
+ )
+ assert len(makedirs_mock.mock_calls) == 2
+
+
def test_generate_roi():
my_chunks = [
{"chunk_size": 30},