diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml index d4e17311..c7cfaea2 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -18,14 +18,14 @@ jobs: actions-ref: main check-schema: - uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.0 + uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.2 with: azure-dir: "" check-package: - uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.11.0 + uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.11.2 with: - actions-ref: v0.11.0 + actions-ref: v0.11.2 import-name: "litdata" artifact-name: dist-packages-${{ github.sha }} testing-matrix: | @@ -35,6 +35,6 @@ jobs: } check-docs: - uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@v0.11.0 + uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@v0.11.2 with: requirements-file: "requirements/docs.txt" diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 4228b9c6..ff1d9cab 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -62,7 +62,15 @@ jobs: key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements.txt') }} restore-keys: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip- + - name: Install package & dependencies on Ubuntu + if: matrix.os == 'ubuntu-latest' + run: | + pip --version + pip install -e '.[extras]' -r requirements/test.txt -U -q --find-links $TORCH_URL + pip list + - name: Install package & dependencies + if: matrix.os != 'ubuntu-latest' run: | pip --version pip install -e . -r requirements/test.txt -U -q --find-links $TORCH_URL diff --git a/README.md b/README.md index 439e1aab..b1f07e25 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,12 @@ Install **LitData** with `pip` pip install litdata ``` +Install **LitData** with the extras + +```bash +pip install 'litdata[extras]' +``` + ## Quick Start ### 1. Prepare Your Data diff --git a/requirements.txt b/requirements.txt index 75908f41..f4a9dd6d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,5 @@ -lightning-cloud == 0.5.64 # Must be pinned to ensure compatibility -lightning-utilities >=0.8.0, <0.11.0 torch >=2.1.0 filelock -tqdm numpy -torchvision -pillow -viztracer -pyarrow boto3[crt] +requests diff --git a/requirements/extras.txt b/requirements/extras.txt index e69de29b..5276521d 100644 --- a/requirements/extras.txt +++ b/requirements/extras.txt @@ -0,0 +1,6 @@ +torchvision +pillow +viztracer +pyarrow +tqdm +lightning-cloud == 0.5.65 # Must be pinned to ensure compatibility diff --git a/requirements/test.txt b/requirements/test.txt index 22696251..758e3e85 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,8 +1,9 @@ -coverage ==7.4.3 +coverage ==7.4.4 pytest ==8.0.2 pytest-cov ==4.1.0 -pytest-timeout ==2.2.0 -pytest-rerunfailures ==12.0 +pytest-timeout ==2.3.1 +pytest-rerunfailures ==14.0 pytest-random-order ==1.1.1 pandas lightning +lightning-cloud == 0.5.65 # Must be pinned to ensure compatibility diff --git a/setup.py b/setup.py index 8e41d40e..441b90d2 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ _PATH_ROOT = os.path.dirname(__file__) _PATH_SOURCE = os.path.join(_PATH_ROOT, "src") -_PATH_REQUIRES = os.path.join(_PATH_ROOT, "_requirements") +_PATH_REQUIRES = os.path.join(_PATH_ROOT, "requirements") def _load_py_module(fname, pkg="litdata"): diff --git a/src/litdata/__init__.py b/src/litdata/__init__.py index 9188b49e..19526299 100644 --- a/src/litdata/__init__.py +++ b/src/litdata/__init__.py @@ -1,6 +1,18 @@ -from lightning_utilities.core.imports import RequirementCache +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from litdata.__about__ import * # noqa: F403 +from litdata.imports import RequirementCache from litdata.processing.functions import map, optimize, walk from litdata.streaming.combined import CombinedStreamingDataset from litdata.streaming.dataloader import StreamingDataLoader diff --git a/src/litdata/constants.py b/src/litdata/constants.py index 90bdd0c2..94bcc59b 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -16,7 +16,8 @@ import numpy as np import torch -from lightning_utilities.core.imports import RequirementCache + +from litdata.imports import RequirementCache _INDEX_FILENAME = "index.json" _DEFAULT_CHUNK_BYTES = 1 << 26 # 64M B @@ -26,7 +27,7 @@ # This is required for full pytree serialization / deserialization support _TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0") _VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer") -_LIGHTNING_CLOUD_LATEST = RequirementCache("lightning-cloud>=0.5.64") +_LIGHTNING_CLOUD_AVAILABLE = RequirementCache("lightning-cloud") _BOTO3_AVAILABLE = RequirementCache("boto3") _TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio") _ZSTD_AVAILABLE = RequirementCache("zstd") diff --git a/src/litdata/imports.py b/src/litdata/imports.py new file mode 100644 index 00000000..3d415569 --- /dev/null +++ b/src/litdata/imports.py @@ -0,0 +1,121 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from functools import lru_cache +from importlib.util import find_spec +from typing import Optional, TypeVar + +import pkg_resources +from typing_extensions import ParamSpec + +T = TypeVar("T") +P = ParamSpec("P") + + +@lru_cache +def package_available(package_name: str) -> bool: + """Check if a package is available in your environment. + + >>> package_available('os') + True + >>> package_available('bla') + False + + """ + try: + return find_spec(package_name) is not None + except ModuleNotFoundError: + return False + + +@lru_cache +def module_available(module_path: str) -> bool: + """Check if a module path is available in your environment. + + >>> module_available('os') + True + >>> module_available('os.bla') + False + >>> module_available('bla.bla') + False + + """ + module_names = module_path.split(".") + if not package_available(module_names[0]): + return False + try: + importlib.import_module(module_path) + except ImportError: + return False + return True + + +class RequirementCache: + """Boolean-like class to check for requirement and module availability. + + Args: + requirement: The requirement to check, version specifiers are allowed. + module: The optional module to try to import if the requirement check fails. + + >>> RequirementCache("torch>=0.1") + Requirement 'torch>=0.1' met + >>> bool(RequirementCache("torch>=0.1")) + True + >>> bool(RequirementCache("torch>100.0")) + False + >>> RequirementCache("torch") + Requirement 'torch' met + >>> bool(RequirementCache("torch")) + True + >>> bool(RequirementCache("unknown_package")) + False + + """ + + def __init__(self, requirement: str, module: Optional[str] = None) -> None: + self.requirement = requirement + self.module = module + + def _check_requirement(self) -> None: + if hasattr(self, "available"): + return + try: + # first try the pkg_resources requirement + pkg_resources.require(self.requirement) + self.available = True + self.message = f"Requirement {self.requirement!r} met" + except Exception as ex: + self.available = False + self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`" + requirement_contains_version_specifier = any(c in self.requirement for c in "=<>") + if not requirement_contains_version_specifier or self.module is not None: + module = self.requirement if self.module is None else self.module + # sometimes `pkg_resources.require()` fails but the module is importable + self.available = module_available(module) + if self.available: + self.message = f"Module {module!r} available" + + def __bool__(self) -> bool: + """Format as bool.""" + self._check_requirement() + return self.available + + def __str__(self) -> str: + """Format as string.""" + self._check_requirement() + return self.message + + def __repr__(self) -> str: + """Format as string.""" + return self.__str__() diff --git a/src/litdata/processing/__init__.py b/src/litdata/processing/__init__.py index e69de29b..27efc081 100644 --- a/src/litdata/processing/__init__.py +++ b/src/litdata/processing/__init__.py @@ -0,0 +1,12 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 4ec03f9f..c2ed1b26 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -1,3 +1,16 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import concurrent import json import logging @@ -19,16 +32,16 @@ import numpy as np import torch -from tqdm.auto import tqdm as _tqdm from litdata.constants import ( _BOTO3_AVAILABLE, _DEFAULT_FAST_DEV_RUN_ITEMS, _INDEX_FILENAME, _IS_IN_STUDIO, - _LIGHTNING_CLOUD_LATEST, + _LIGHTNING_CLOUD_AVAILABLE, _TORCH_GREATER_EQUAL_2_1_0, ) +from litdata.imports import RequirementCache from litdata.processing.readers import BaseReader, StreamingDataLoaderReader from litdata.processing.utilities import _create_dataset from litdata.streaming import Cache @@ -39,10 +52,15 @@ from litdata.utilities.broadcast import broadcast_object from litdata.utilities.packing import _pack_greedily +_TQDM_AVAILABLE = RequirementCache("tqdm") + +if _TQDM_AVAILABLE: + from tqdm.auto import tqdm as _tqdm + if _TORCH_GREATER_EQUAL_2_1_0: from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_loads -if _LIGHTNING_CLOUD_LATEST: +if _LIGHTNING_CLOUD_AVAILABLE: from lightning_cloud.openapi import V1DatasetType @@ -947,15 +965,16 @@ def run(self, data_recipe: DataRecipe) -> None: print("Workers are ready ! Starting data processing...") current_total = 0 - pbar = _tqdm( - desc="Progress", - total=num_items, - smoothing=0, - position=-1, - mininterval=1, - leave=True, - dynamic_ncols=True, - ) + if _TQDM_AVAILABLE: + pbar = _tqdm( + desc="Progress", + total=num_items, + smoothing=0, + position=-1, + mininterval=1, + leave=True, + dynamic_ncols=True, + ) num_nodes = _get_num_nodes() node_rank = _get_node_rank() total_num_items = len(user_items) @@ -973,7 +992,8 @@ def run(self, data_recipe: DataRecipe) -> None: self.workers_tracker[index] = counter new_total = sum(self.workers_tracker.values()) - pbar.update(new_total - current_total) + if _TQDM_AVAILABLE: + pbar.update(new_total - current_total) current_total = new_total if current_total == num_items: @@ -988,7 +1008,8 @@ def run(self, data_recipe: DataRecipe) -> None: if all(not w.is_alive() for w in self.workers): raise RuntimeError("One of the worker has failed") - pbar.close() + if _TQDM_AVAILABLE: + pbar.close() # TODO: Understand why it hangs. if num_nodes == 1: diff --git a/src/litdata/processing/readers.py b/src/litdata/processing/readers.py index 38d1d5fc..f87c536c 100644 --- a/src/litdata/processing/readers.py +++ b/src/litdata/processing/readers.py @@ -1,14 +1,33 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import contextlib import os from abc import ABC, abstractmethod from typing import Any, List -from lightning_utilities.core.imports import RequirementCache -from tqdm import tqdm - +from litdata.imports import RequirementCache from litdata.streaming.dataloader import StreamingDataLoader _PYARROW_AVAILABLE = RequirementCache("pyarrow") +_TQDM_AVAILABLE = RequirementCache("tqdm") + +if _TQDM_AVAILABLE: + from tqdm.auto import tqdm as _tqdm +else: + + def _tqdm(iterator: Any) -> Any: + yield from iterator class BaseReader(ABC): @@ -79,7 +98,7 @@ def remap_items(self, filepaths: List[str], _: int) -> List[str]: table = None parquet_filename = os.path.basename(filepath) - for start in tqdm(range(0, num_rows, self.num_rows)): + for start in _tqdm(range(0, num_rows, self.num_rows)): end = min(start + self.num_rows, num_rows) chunk_filepath = os.path.join(cache_folder, f"{start}_{end}_{parquet_filename}") new_items.append(chunk_filepath) diff --git a/src/litdata/processing/utilities.py b/src/litdata/processing/utilities.py index 1ca7142a..84c18097 100644 --- a/src/litdata/processing/utilities.py +++ b/src/litdata/processing/utilities.py @@ -1,3 +1,16 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import io import os import urllib @@ -5,12 +18,11 @@ from subprocess import DEVNULL, Popen from typing import Any, Callable, List, Optional, Tuple, Union -from litdata.constants import _IS_IN_STUDIO, _LIGHTNING_CLOUD_LATEST +from litdata.constants import _IS_IN_STUDIO, _LIGHTNING_CLOUD_AVAILABLE -if _LIGHTNING_CLOUD_LATEST: +if _LIGHTNING_CLOUD_AVAILABLE: from lightning_cloud.openapi import ( ProjectIdDatasetsBody, - V1DatasetType, ) from lightning_cloud.openapi.rest import ApiException from lightning_cloud.rest_client import LightningClient @@ -19,7 +31,7 @@ def _create_dataset( input_dir: Optional[str], storage_dir: str, - dataset_type: V1DatasetType, + dataset_type: Any, empty: Optional[bool] = None, size: Optional[int] = None, num_bytes: Optional[str] = None, diff --git a/src/litdata/streaming/cache.py b/src/litdata/streaming/cache.py index 18d78e31..5d00b97e 100644 --- a/src/litdata/streaming/cache.py +++ b/src/litdata/streaming/cache.py @@ -17,7 +17,6 @@ from litdata.constants import ( _INDEX_FILENAME, - _LIGHTNING_CLOUD_LATEST, _TORCH_GREATER_EQUAL_2_1_0, ) from litdata.streaming.item_loader import BaseItemLoader @@ -60,12 +59,6 @@ def __init__( if not _TORCH_GREATER_EQUAL_2_1_0: raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.") - if not _LIGHTNING_CLOUD_LATEST: - raise ModuleNotFoundError( - "The `lightning-cloud` package in your environement is out-dated." - " Run: `pip install -U lightning-cloud` to resolve this." - ) - input_dir = _resolve_dir(input_dir) self._cache_dir = input_dir.path assert self._cache_dir diff --git a/src/litdata/streaming/client.py b/src/litdata/streaming/client.py index 28a0566f..354e82a5 100644 --- a/src/litdata/streaming/client.py +++ b/src/litdata/streaming/client.py @@ -1,3 +1,16 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from time import time from typing import Any, Optional diff --git a/src/litdata/streaming/compression.py b/src/litdata/streaming/compression.py index 75a47945..21b30cf8 100644 --- a/src/litdata/streaming/compression.py +++ b/src/litdata/streaming/compression.py @@ -14,15 +14,13 @@ from abc import ABC, abstractmethod from typing import Dict, TypeVar -from lightning_utilities.core.imports import requires - from litdata.constants import _ZSTD_AVAILABLE +TCompressor = TypeVar("TCompressor", bound="Compressor") + if _ZSTD_AVAILABLE: import zstd -TCompressor = TypeVar("TCompressor", bound="Compressor") - class Compressor(ABC): """Base class for compression algorithm.""" @@ -44,9 +42,10 @@ def register(cls, compressors: Dict[str, "Compressor"]) -> None: class ZSTDCompressor(Compressor): """Compressor for the zstd package.""" - @requires("zstd") def __init__(self, level: int) -> None: super().__init__() + if not _ZSTD_AVAILABLE: + raise ModuleNotFoundError() self.level = level self.extension = "zstd" diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 756bbadd..6ee36767 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -10,6 +10,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import shutil import subprocess diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index 9bf5560b..938b3cf3 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -1,3 +1,16 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import datetime import os import re @@ -5,13 +18,14 @@ from dataclasses import dataclass from pathlib import Path from time import sleep -from typing import Optional, Union +from typing import Any, Optional, Union from urllib import parse -from lightning_cloud.openapi import V1CloudSpace -from lightning_cloud.rest_client import LightningClient +from litdata.constants import _LIGHTNING_CLOUD_AVAILABLE + +if _LIGHTNING_CLOUD_AVAILABLE: + from lightning_cloud.rest_client import LightningClient -# To avoid adding lightning_utilities as a dependency for now. try: import boto3 import botocore @@ -81,7 +95,7 @@ def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir: return Dir(path=dir_path_absolute, url=None) -def _match_studio(target_id: Optional[str], target_name: Optional[str], cloudspace: V1CloudSpace) -> bool: +def _match_studio(target_id: Optional[str], target_name: Optional[str], cloudspace: Any) -> bool: if cloudspace.name is not None and target_name is not None and cloudspace.name.lower() == target_name.lower(): return True diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index 4d7923ce..1e6f8c79 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -21,9 +21,9 @@ import numpy as np import torch -from lightning_utilities.core.imports import RequirementCache from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING +from litdata.imports import RequirementCache _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") @@ -36,9 +36,19 @@ from PIL.PngImagePlugin import PngImageFile from PIL.WebPImagePlugin import WebPImageFile else: - Image = None - JpegImageFile = None - PngImageFile = None + + class Image: # type: ignore + Image = None + + class JpegImageFile: # type: ignore + pass + + class PngImageFile: # type: ignore + pass + + class WebPImageFile: # type: ignore + pass + if _TORCH_VISION_AVAILABLE: from torchvision.io import decode_jpeg @@ -71,7 +81,7 @@ def setup(self, metadata: Any) -> None: class PILSerializer(Serializer): """The PILSerializer serialize and deserialize PIL Image to and from bytes.""" - def serialize(self, item: Image) -> Tuple[bytes, Optional[str]]: + def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: mode = item.mode.encode("utf-8") width, height = item.size raw = item.tobytes() @@ -95,7 +105,7 @@ def can_serialize(self, item: Any) -> bool: class JPEGSerializer(Serializer): """The JPEGSerializer serialize and deserialize JPEG image to and from bytes.""" - def serialize(self, item: Image) -> Tuple[bytes, Optional[str]]: + def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: if isinstance(item, JpegImageFile): if not hasattr(item, "filename"): raise ValueError( diff --git a/src/litdata/utilities/__init__.py b/src/litdata/utilities/__init__.py index e69de29b..27efc081 100644 --- a/src/litdata/utilities/__init__.py +++ b/src/litdata/utilities/__init__.py @@ -0,0 +1,12 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/litdata/utilities/env.py b/src/litdata/utilities/env.py index e30e9e7b..7276a87e 100644 --- a/src/litdata/utilities/env.py +++ b/src/litdata/utilities/env.py @@ -1,3 +1,16 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from typing import Callable, Optional diff --git a/src/litdata/utilities/format.py b/src/litdata/utilities/format.py index 46948aa6..946b1063 100644 --- a/src/litdata/utilities/format.py +++ b/src/litdata/utilities/format.py @@ -1,3 +1,16 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + _FORMAT_TO_RATIO = { "kb": 1000, "mb": 1000**2, diff --git a/src/litdata/utilities/packing.py b/src/litdata/utilities/packing.py index 309a32d7..3b1c8480 100644 --- a/src/litdata/utilities/packing.py +++ b/src/litdata/utilities/packing.py @@ -1,3 +1,16 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from collections import defaultdict from typing import Any, Dict, List, Tuple diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 8b24a6de..14160d1e 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -1,3 +1,16 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Any, List, Tuple import numpy as np diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py index aa3465bb..0e4e5372 100644 --- a/tests/processing/test_data_processor.py +++ b/tests/processing/test_data_processor.py @@ -9,9 +9,8 @@ import numpy as np import pytest import torch -from lightning import seed_everything -from lightning_utilities.core.imports import RequirementCache from litdata.constants import _TORCH_AUDIO_AVAILABLE, _ZSTD_AVAILABLE +from litdata.imports import RequirementCache from litdata.processing import data_processor as data_processor_module from litdata.processing import functions from litdata.processing.data_processor import ( @@ -33,6 +32,13 @@ from litdata.streaming import StreamingDataLoader, StreamingDataset, resolver from litdata.streaming.cache import Cache, Dir + +def seed_everything(random_seed): + random.seed(random_seed) + np.random.seed(random_seed) + torch.manual_seed(random_seed) + + _PIL_AVAILABLE = RequirementCache("PIL") diff --git a/tests/streaming/test_cache.py b/tests/streaming/test_cache.py index 5df649b2..9bb97e65 100644 --- a/tests/streaming/test_cache.py +++ b/tests/streaming/test_cache.py @@ -12,17 +12,16 @@ # limitations under the License. import json import os +import random import sys from functools import partial import numpy as np import pytest import torch -from lightning import seed_everything -from lightning.fabric import Fabric from lightning.pytorch.demos.boring_classes import RandomDataset -from lightning_utilities.core.imports import RequirementCache from lightning_utilities.test.warning import no_warning_call +from litdata.imports import RequirementCache from litdata.streaming import Cache from litdata.streaming.dataloader import CacheDataLoader from litdata.streaming.dataset import StreamingDataset @@ -31,8 +30,16 @@ from litdata.utilities.env import _DistributedEnv from torch.utils.data import Dataset + +def seed_everything(random_seed): + random.seed(random_seed) + np.random.seed(random_seed) + torch.manual_seed(random_seed) + + _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") +_LIGHTNING_AVAILABLE = RequirementCache("lightning") class ImageDataset(Dataset): @@ -145,7 +152,7 @@ def _fabric_cache_for_image_dataset(fabric, num_workers, tmpdir): @pytest.mark.skipif( - condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE or sys.platform == "win32", + condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE or sys.platform == "win32" or not _LIGHTNING_AVAILABLE, reason="Requires: ['pil', 'torchvision']", ) @pytest.mark.parametrize("num_workers", [2]) @@ -153,6 +160,8 @@ def test_cache_for_image_dataset_distributed(num_workers, tmpdir): cache_dir = os.path.join(tmpdir, "cache") os.makedirs(cache_dir) + from lightning.fabric import Fabric + fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp_spawn") fabric.launch(partial(_fabric_cache_for_image_dataset, num_workers=num_workers, tmpdir=tmpdir)) diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index 056826a9..d719bce6 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -2,6 +2,7 @@ import pytest import torch +from litdata.constants import _VIZ_TRACKER_AVAILABLE from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader from litdata.streaming import dataloader as streaming_dataloader_module from torch import tensor @@ -80,6 +81,7 @@ def test_streaming_dataloader(): } +@pytest.mark.skipif(not _VIZ_TRACKER_AVAILABLE, reason="viz tracker required") @pytest.mark.parametrize("profile", [2, True]) def test_dataloader_profiling(profile, tmpdir, monkeypatch): monkeypatch.setattr(streaming_dataloader_module, "_VIZ_TRACKER_AVAILABLE", True) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index eda5797d..b1786931 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -12,6 +12,7 @@ # limitations under the License. import os +import random import sys from time import sleep from unittest import mock @@ -19,7 +20,6 @@ import numpy as np import pytest import torch -from lightning import seed_everything from litdata.processing import functions from litdata.streaming import Cache from litdata.streaming import dataset as dataset_module @@ -40,6 +40,12 @@ from torch.utils.data import DataLoader +def seed_everything(random_seed): + random.seed(random_seed) + np.random.seed(random_seed) + torch.manual_seed(random_seed) + + def test_streaming_dataset(tmpdir, monkeypatch): seed_everything(42) diff --git a/tests/streaming/test_sampler.py b/tests/streaming/test_sampler.py index 426ae06d..d640960d 100644 --- a/tests/streaming/test_sampler.py +++ b/tests/streaming/test_sampler.py @@ -1,10 +1,18 @@ +import random from unittest import mock +import numpy as np import pytest -from lightning import seed_everything +import torch from litdata.streaming.sampler import CacheBatchSampler +def seed_everything(random_seed): + random.seed(random_seed) + np.random.seed(random_seed) + torch.manual_seed(random_seed) + + @pytest.mark.parametrize( "params", [ diff --git a/tests/streaming/test_serializer.py b/tests/streaming/test_serializer.py index 28c66de9..08357925 100644 --- a/tests/streaming/test_serializer.py +++ b/tests/streaming/test_serializer.py @@ -13,14 +13,14 @@ import io import os +import random import sys from time import time import numpy as np import pytest import torch -from lightning import seed_everything -from lightning_utilities.core.imports import RequirementCache +from litdata.imports import RequirementCache from litdata.streaming.serializers import ( _AV_AVAILABLE, _NUMPY_DTYPES_MAPPING, @@ -38,6 +38,13 @@ VideoSerializer, ) + +def seed_everything(random_seed): + random.seed(random_seed) + np.random.seed(random_seed) + torch.manual_seed(random_seed) + + _PIL_AVAILABLE = RequirementCache("PIL") diff --git a/tests/streaming/test_writer.py b/tests/streaming/test_writer.py index 3b9e6ff0..807442c5 100644 --- a/tests/streaming/test_writer.py +++ b/tests/streaming/test_writer.py @@ -13,18 +13,26 @@ import json import os +import random import sys import numpy as np import pytest -from lightning import seed_everything -from lightning_utilities.core.imports import RequirementCache +import torch +from litdata.imports import RequirementCache from litdata.streaming.compression import _ZSTD_AVAILABLE from litdata.streaming.reader import BinaryReader from litdata.streaming.sampler import ChunkedIndex from litdata.streaming.writer import BinaryWriter from litdata.utilities.format import _FORMAT_TO_RATIO + +def seed_everything(random_seed): + random.seed(random_seed) + np.random.seed(random_seed) + torch.manual_seed(random_seed) + + _PIL_AVAILABLE = RequirementCache("PIL")