Skip to content

Commit

Permalink
Merge branch 'main' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton committed Apr 3, 2024
2 parents 1392569 + 015a546 commit 7af430c
Show file tree
Hide file tree
Showing 32 changed files with 428 additions and 75 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/ci-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ jobs:
actions-ref: main

check-schema:
uses: Lightning-AI/utilities/.github/workflows/[email protected].0
uses: Lightning-AI/utilities/.github/workflows/[email protected].2
with:
azure-dir: ""

check-package:
uses: Lightning-AI/utilities/.github/workflows/[email protected].0
uses: Lightning-AI/utilities/.github/workflows/[email protected].2
with:
actions-ref: v0.11.0
actions-ref: v0.11.2
import-name: "litdata"
artifact-name: dist-packages-${{ github.sha }}
testing-matrix: |
Expand All @@ -35,6 +35,6 @@ jobs:
}
check-docs:
uses: Lightning-AI/utilities/.github/workflows/[email protected].0
uses: Lightning-AI/utilities/.github/workflows/[email protected].2
with:
requirements-file: "requirements/docs.txt"
8 changes: 8 additions & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,15 @@ jobs:
key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements.txt') }}
restore-keys: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-

- name: Install package & dependencies on Ubuntu
if: matrix.os == 'ubuntu-latest'
run: |
pip --version
pip install -e '.[extras]' -r requirements/test.txt -U -q --find-links $TORCH_URL
pip list
- name: Install package & dependencies
if: matrix.os != 'ubuntu-latest'
run: |
pip --version
pip install -e . -r requirements/test.txt -U -q --find-links $TORCH_URL
Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ Install **LitData** with `pip`
pip install litdata
```

Install **LitData** with the extras

```bash
pip install 'litdata[extras]'
```

## Quick Start

### 1. Prepare Your Data
Expand Down
8 changes: 1 addition & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
lightning-cloud == 0.5.64 # Must be pinned to ensure compatibility
lightning-utilities >=0.8.0, <0.11.0
torch >=2.1.0
filelock
tqdm
numpy
torchvision
pillow
viztracer
pyarrow
boto3[crt]
requests
6 changes: 6 additions & 0 deletions requirements/extras.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
torchvision
pillow
viztracer
pyarrow
tqdm
lightning-cloud == 0.5.65 # Must be pinned to ensure compatibility
7 changes: 4 additions & 3 deletions requirements/test.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
coverage ==7.4.3
coverage ==7.4.4
pytest ==8.0.2
pytest-cov ==4.1.0
pytest-timeout ==2.2.0
pytest-rerunfailures ==12.0
pytest-timeout ==2.3.1
pytest-rerunfailures ==14.0
pytest-random-order ==1.1.1
pandas
lightning
lightning-cloud == 0.5.65 # Must be pinned to ensure compatibility
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

_PATH_ROOT = os.path.dirname(__file__)
_PATH_SOURCE = os.path.join(_PATH_ROOT, "src")
_PATH_REQUIRES = os.path.join(_PATH_ROOT, "_requirements")
_PATH_REQUIRES = os.path.join(_PATH_ROOT, "requirements")


def _load_py_module(fname, pkg="litdata"):
Expand Down
14 changes: 13 additions & 1 deletion src/litdata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
from lightning_utilities.core.imports import RequirementCache
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from litdata.__about__ import * # noqa: F403
from litdata.imports import RequirementCache
from litdata.processing.functions import map, optimize, walk
from litdata.streaming.combined import CombinedStreamingDataset
from litdata.streaming.dataloader import StreamingDataLoader
Expand Down
5 changes: 3 additions & 2 deletions src/litdata/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

import numpy as np
import torch
from lightning_utilities.core.imports import RequirementCache

from litdata.imports import RequirementCache

_INDEX_FILENAME = "index.json"
_DEFAULT_CHUNK_BYTES = 1 << 26 # 64M B
Expand All @@ -26,7 +27,7 @@
# This is required for full pytree serialization / deserialization support
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
_LIGHTNING_CLOUD_LATEST = RequirementCache("lightning-cloud>=0.5.64")
_LIGHTNING_CLOUD_AVAILABLE = RequirementCache("lightning-cloud")
_BOTO3_AVAILABLE = RequirementCache("boto3")
_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
_ZSTD_AVAILABLE = RequirementCache("zstd")
Expand Down
121 changes: 121 additions & 0 deletions src/litdata/imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from functools import lru_cache
from importlib.util import find_spec
from typing import Optional, TypeVar

import pkg_resources
from typing_extensions import ParamSpec

T = TypeVar("T")
P = ParamSpec("P")


@lru_cache
def package_available(package_name: str) -> bool:
"""Check if a package is available in your environment.
>>> package_available('os')
True
>>> package_available('bla')
False
"""
try:
return find_spec(package_name) is not None
except ModuleNotFoundError:
return False


@lru_cache
def module_available(module_path: str) -> bool:
"""Check if a module path is available in your environment.
>>> module_available('os')
True
>>> module_available('os.bla')
False
>>> module_available('bla.bla')
False
"""
module_names = module_path.split(".")
if not package_available(module_names[0]):
return False
try:
importlib.import_module(module_path)
except ImportError:
return False
return True


class RequirementCache:
"""Boolean-like class to check for requirement and module availability.
Args:
requirement: The requirement to check, version specifiers are allowed.
module: The optional module to try to import if the requirement check fails.
>>> RequirementCache("torch>=0.1")
Requirement 'torch>=0.1' met
>>> bool(RequirementCache("torch>=0.1"))
True
>>> bool(RequirementCache("torch>100.0"))
False
>>> RequirementCache("torch")
Requirement 'torch' met
>>> bool(RequirementCache("torch"))
True
>>> bool(RequirementCache("unknown_package"))
False
"""

def __init__(self, requirement: str, module: Optional[str] = None) -> None:
self.requirement = requirement
self.module = module

def _check_requirement(self) -> None:
if hasattr(self, "available"):
return
try:
# first try the pkg_resources requirement
pkg_resources.require(self.requirement)
self.available = True
self.message = f"Requirement {self.requirement!r} met"
except Exception as ex:
self.available = False
self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`"
requirement_contains_version_specifier = any(c in self.requirement for c in "=<>")
if not requirement_contains_version_specifier or self.module is not None:
module = self.requirement if self.module is None else self.module
# sometimes `pkg_resources.require()` fails but the module is importable
self.available = module_available(module)
if self.available:
self.message = f"Module {module!r} available"

def __bool__(self) -> bool:
"""Format as bool."""
self._check_requirement()
return self.available

def __str__(self) -> str:
"""Format as string."""
self._check_requirement()
return self.message

def __repr__(self) -> str:
"""Format as string."""
return self.__str__()
12 changes: 12 additions & 0 deletions src/litdata/processing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
49 changes: 35 additions & 14 deletions src/litdata/processing/data_processor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import concurrent
import json
import logging
Expand All @@ -19,16 +32,16 @@

import numpy as np
import torch
from tqdm.auto import tqdm as _tqdm

from litdata.constants import (
_BOTO3_AVAILABLE,
_DEFAULT_FAST_DEV_RUN_ITEMS,
_INDEX_FILENAME,
_IS_IN_STUDIO,
_LIGHTNING_CLOUD_LATEST,
_LIGHTNING_CLOUD_AVAILABLE,
_TORCH_GREATER_EQUAL_2_1_0,
)
from litdata.imports import RequirementCache
from litdata.processing.readers import BaseReader, StreamingDataLoaderReader
from litdata.processing.utilities import _create_dataset
from litdata.streaming import Cache
Expand All @@ -39,10 +52,15 @@
from litdata.utilities.broadcast import broadcast_object
from litdata.utilities.packing import _pack_greedily

_TQDM_AVAILABLE = RequirementCache("tqdm")

if _TQDM_AVAILABLE:
from tqdm.auto import tqdm as _tqdm

if _TORCH_GREATER_EQUAL_2_1_0:
from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_loads

if _LIGHTNING_CLOUD_LATEST:
if _LIGHTNING_CLOUD_AVAILABLE:
from lightning_cloud.openapi import V1DatasetType


Expand Down Expand Up @@ -947,15 +965,16 @@ def run(self, data_recipe: DataRecipe) -> None:
print("Workers are ready ! Starting data processing...")

current_total = 0
pbar = _tqdm(
desc="Progress",
total=num_items,
smoothing=0,
position=-1,
mininterval=1,
leave=True,
dynamic_ncols=True,
)
if _TQDM_AVAILABLE:
pbar = _tqdm(
desc="Progress",
total=num_items,
smoothing=0,
position=-1,
mininterval=1,
leave=True,
dynamic_ncols=True,
)
num_nodes = _get_num_nodes()
node_rank = _get_node_rank()
total_num_items = len(user_items)
Expand All @@ -973,7 +992,8 @@ def run(self, data_recipe: DataRecipe) -> None:
self.workers_tracker[index] = counter
new_total = sum(self.workers_tracker.values())

pbar.update(new_total - current_total)
if _TQDM_AVAILABLE:
pbar.update(new_total - current_total)

current_total = new_total
if current_total == num_items:
Expand All @@ -988,7 +1008,8 @@ def run(self, data_recipe: DataRecipe) -> None:
if all(not w.is_alive() for w in self.workers):
raise RuntimeError("One of the worker has failed")

pbar.close()
if _TQDM_AVAILABLE:
pbar.close()

# TODO: Understand why it hangs.
if num_nodes == 1:
Expand Down
Loading

0 comments on commit 7af430c

Please sign in to comment.