From 897caab2b207bea143107c7063d82486fd4b7fd8 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 11 Jul 2024 12:44:39 +0200
Subject: [PATCH] Refactor optional imports

---
 src/litdata/constants.py                 |  2 +-
 src/litdata/processing/data_processor.py | 18 +++------
 src/litdata/processing/functions.py      | 10 ++---
 src/litdata/processing/readers.py        | 11 ++----
 src/litdata/processing/utilities.py      | 22 ++++-------
 src/litdata/streaming/client.py          | 11 +++---
 src/litdata/streaming/compression.py     |  9 +++--
 src/litdata/streaming/downloader.py      | 12 +++---
 src/litdata/streaming/resolver.py        | 37 +++++++-----------
 src/litdata/streaming/serializers.py     | 48 +++++++++++------------
 src/litdata/utilities/format.py          | 15 ++++++++
 11 files changed, 85 insertions(+), 110 deletions(-)

diff --git a/src/litdata/constants.py b/src/litdata/constants.py
index 8befa208..6caa7520 100644
--- a/src/litdata/constants.py
+++ b/src/litdata/constants.py
@@ -33,7 +33,7 @@
 _ZSTD_AVAILABLE = RequirementCache("zstd")
 _GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage")
 _TQDM_AVAILABLE = RequirementCache("tqdm")
-
+_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
 
 # DON'T CHANGE ORDER
 _TORCH_DTYPES_MAPPING = {
diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py
index 53cf3350..8fef6d9b 100644
--- a/src/litdata/processing/data_processor.py
+++ b/src/litdata/processing/data_processor.py
@@ -34,12 +34,10 @@
 import torch
 
 from litdata.constants import (
-    _BOTO3_AVAILABLE,
     _DEFAULT_FAST_DEV_RUN_ITEMS,
     _ENABLE_STATUS,
     _INDEX_FILENAME,
     _IS_IN_STUDIO,
-    _LIGHTNING_CLOUD_AVAILABLE,
     _TQDM_AVAILABLE,
 )
 from litdata.processing.readers import BaseReader, StreamingDataLoaderReader
@@ -53,16 +51,8 @@
 from litdata.utilities.broadcast import broadcast_object
 from litdata.utilities.dataset_utilities import load_index_file
 from litdata.utilities.packing import _pack_greedily
-
-if _TQDM_AVAILABLE:
-    from tqdm.auto import tqdm as _tqdm
-
-if _LIGHTNING_CLOUD_AVAILABLE:
-    from lightning_cloud.openapi import V1DatasetType
-
-if _BOTO3_AVAILABLE:
-    import boto3
-    import botocore
+import boto3
+import botocore
 
 logger = logging.Logger(__name__)
 
@@ -1043,6 +1033,8 @@ def run(self, data_recipe: DataRecipe) -> None:
 
         current_total = 0
         if _TQDM_AVAILABLE:
+            from tqdm.auto import tqdm as _tqdm
+
             pbar = _tqdm(
                 desc="Progress",
                 total=num_items,
@@ -1097,6 +1089,8 @@ def run(self, data_recipe: DataRecipe) -> None:
         result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)
 
         if num_nodes == node_rank + 1 and self.output_dir.url and self.output_dir.path is not None and _IS_IN_STUDIO:
+            from lightning_cloud.openapi import V1DatasetType
+
             _create_dataset(
                 input_dir=self.input_dir.path,
                 storage_dir=self.output_dir.path,
diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py
index adb5cdeb..a127a108 100644
--- a/src/litdata/processing/functions.py
+++ b/src/litdata/processing/functions.py
@@ -46,13 +46,7 @@
     _resolve_dir,
 )
 from litdata.utilities._pytree import tree_flatten
-
-if _TQDM_AVAILABLE:
-    from tqdm.auto import tqdm as _tqdm
-else:
-
-    def _tqdm(iterator: Any) -> Any:
-        yield from iterator
+from litdata.utilities.format import _get_tqdm_iterator_if_available
 
 
 def _is_remote_file(path: str) -> bool:
@@ -563,6 +557,8 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None:
 
     index_json = {"config": input_dirs_file_content[0]["config"], "chunks": chunks}  # type: ignore
 
+    _tqdm = _get_tqdm_iterator_if_available()
+
     for copy_info in _tqdm(copy_infos):
         _apply_copy(copy_info, resolved_output_dir)
 
diff --git a/src/litdata/processing/readers.py b/src/litdata/processing/readers.py
index eb3c1511..f67d0f88 100644
--- a/src/litdata/processing/readers.py
+++ b/src/litdata/processing/readers.py
@@ -16,19 +16,12 @@
 from abc import ABC, abstractmethod
 from typing import Any, List
 
-from litdata.constants import _TQDM_AVAILABLE
 from litdata.imports import RequirementCache
 from litdata.streaming.dataloader import StreamingDataLoader
+from litdata.utilities.format import _get_tqdm_iterator_if_available
 
 _PYARROW_AVAILABLE = RequirementCache("pyarrow")
 
-if _TQDM_AVAILABLE:
-    from tqdm.auto import tqdm as _tqdm
-else:
-
-    def _tqdm(iterator: Any) -> Any:
-        yield from iterator
-
 
 class BaseReader(ABC):
     def get_num_nodes(self) -> int:
@@ -92,6 +85,8 @@ def remap_items(self, filepaths: List[str], _: int) -> List[str]:
         cache_folder = os.path.join(self.cache_folder, f"{self.num_rows}")
         os.makedirs(cache_folder, exist_ok=True)
 
+        _tqdm = _get_tqdm_iterator_if_available()
+
         for filepath in filepaths:
             num_rows = self._get_num_rows(filepath)
 
diff --git a/src/litdata/processing/utilities.py b/src/litdata/processing/utilities.py
index 8312ea39..2b193e48 100644
--- a/src/litdata/processing/utilities.py
+++ b/src/litdata/processing/utilities.py
@@ -21,23 +21,11 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from urllib import parse
 
-from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _LIGHTNING_CLOUD_AVAILABLE
+from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO
 from litdata.streaming.cache import Dir
 
-if _LIGHTNING_CLOUD_AVAILABLE:
-    from lightning_cloud.openapi import (
-        ProjectIdDatasetsBody,
-    )
-    from lightning_cloud.openapi.rest import ApiException
-    from lightning_cloud.rest_client import LightningClient
-
-try:
-    import boto3
-    import botocore
-
-    _BOTO3_AVAILABLE = True
-except Exception:
-    _BOTO3_AVAILABLE = False
+import boto3
+import botocore
 
 
 def _create_dataset(
@@ -67,6 +55,10 @@ def _create_dataset(
     if not storage_dir:
         raise ValueError("The storage_dir should be defined.")
 
+    from lightning_cloud.openapi import ProjectIdDatasetsBody
+    from lightning_cloud.openapi.rest import ApiException
+    from lightning_cloud.rest_client import LightningClient
+
     client = LightningClient(retry=False)
 
     try:
diff --git a/src/litdata/streaming/client.py b/src/litdata/streaming/client.py
index 354e82a5..37d75d61 100644
--- a/src/litdata/streaming/client.py
+++ b/src/litdata/streaming/client.py
@@ -15,13 +15,12 @@
 from time import time
 from typing import Any, Optional
 
-from litdata.constants import _BOTO3_AVAILABLE, _IS_IN_STUDIO
+from litdata.constants import _IS_IN_STUDIO
 
-if _BOTO3_AVAILABLE:
-    import boto3
-    import botocore
-    from botocore.credentials import InstanceMetadataProvider
-    from botocore.utils import InstanceMetadataFetcher
+import boto3
+import botocore
+from botocore.credentials import InstanceMetadataProvider
+from botocore.utils import InstanceMetadataFetcher
 
 
 class S3Client:
diff --git a/src/litdata/streaming/compression.py b/src/litdata/streaming/compression.py
index 21b30cf8..248f0781 100644
--- a/src/litdata/streaming/compression.py
+++ b/src/litdata/streaming/compression.py
@@ -18,9 +18,6 @@
 
 TCompressor = TypeVar("TCompressor", bound="Compressor")
 
-if _ZSTD_AVAILABLE:
-    import zstd
-
 
 class Compressor(ABC):
     """Base class for compression algorithm."""
@@ -45,7 +42,7 @@ class ZSTDCompressor(Compressor):
     def __init__(self, level: int) -> None:
         super().__init__()
         if not _ZSTD_AVAILABLE:
-            raise ModuleNotFoundError()
+            raise ModuleNotFoundError(str(_ZSTD_AVAILABLE))
 
         self.level = level
         self.extension = "zstd"
@@ -54,9 +51,13 @@ def name(self) -> str:
         return f"{self.extension}:{self.level}"
 
     def compress(self, data: bytes) -> bytes:
+        import zstd
+
         return zstd.compress(data, self.level)
 
     def decompress(self, data: bytes) -> bytes:
+        import zstd
+
         return zstd.decompress(data)
 
     @classmethod
diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py
index da5c7a9e..eba85c5e 100644
--- a/src/litdata/streaming/downloader.py
+++ b/src/litdata/streaming/downloader.py
@@ -23,13 +23,6 @@
 from litdata.constants import _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME
 from litdata.streaming.client import S3Client
 
-if _GOOGLE_STORAGE_AVAILABLE:
-    from google.cloud import storage
-else:
-
-    class storage:  # type: ignore
-        Client = None
-
 
 class Downloader(ABC):
     def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
@@ -96,9 +89,14 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
 
 class GCPDownloader(Downloader):
     def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
+        if not _GOOGLE_STORAGE_AVAILABLE:
+            raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE))
+
         super().__init__(remote_dir, cache_dir, chunks)
 
     def download_file(self, remote_filepath: str, local_filepath: str) -> None:
+        from google.cloud import storage
+
         obj = parse.urlparse(remote_filepath)
 
         if obj.scheme != "gs":
diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py
index 633ebc05..d7c6971d 100644
--- a/src/litdata/streaming/resolver.py
+++ b/src/litdata/streaming/resolver.py
@@ -20,33 +20,16 @@
 from dataclasses import dataclass
 from pathlib import Path
 from time import sleep
-from typing import Any, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 from urllib import parse
 
-from litdata.constants import _LIGHTNING_CLOUD_AVAILABLE
+from litdata.constants import _LIGHTNING_SDK_AVAILABLE
 
-if _LIGHTNING_CLOUD_AVAILABLE:
-    from lightning_cloud.rest_client import LightningClient
-
-try:
-    import boto3
-    import botocore
-
-    _BOTO3_AVAILABLE = True
-except Exception:
-    _BOTO3_AVAILABLE = False
-
-
-try:
-    from lightning_sdk import Machine, Studio
-
-    _LIGHTNING_SDK_AVAILABLE = True
-except (ImportError, ModuleNotFoundError):
-
-    class Machine:  # type: ignore
-        pass
+import boto3
+import botocore
 
-    _LIGHTNING_SDK_AVAILABLE = False
+if TYPE_CHECKING:
+    from lightning_sdk import Machine
 
 
 @dataclass
@@ -115,6 +98,8 @@ def _match_studio(target_id: Optional[str], target_name: Optional[str], cloudspa
 
 
 def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Optional[str]) -> Dir:
+    from lightning_cloud.rest_client import LightningClient
+
     client = LightningClient(max_tries=2)
 
     # Get the ids from env variables
@@ -154,6 +139,8 @@ def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Option
 
 
 def _resolve_s3_connections(dir_path: str) -> Dir:
+    from lightning_cloud.rest_client import LightningClient
+
     client = LightningClient(max_tries=2)
 
     # Get the ids from env variables
@@ -174,6 +161,8 @@ def _resolve_s3_connections(dir_path: str) -> Dir:
 
 
 def _resolve_datasets(dir_path: str) -> Dir:
+    from lightning_cloud.rest_client import LightningClient
+
     client = LightningClient(max_tries=2)
 
     # Get the ids from env variables
@@ -362,6 +351,8 @@ def _execute(
     if not _LIGHTNING_SDK_AVAILABLE:
         raise ModuleNotFoundError("The `lightning_sdk` is required.")
 
+    from lightning_sdk import Studio
+
     lightning_skip_install = os.getenv("LIGHTNING_SKIP_INSTALL", "")
     if lightning_skip_install:
         lightning_skip_install = f" LIGHTNING_SKIP_INSTALL={lightning_skip_install} "
diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py
index 87928fa5..6e0f32ac 100644
--- a/src/litdata/streaming/serializers.py
+++ b/src/litdata/streaming/serializers.py
@@ -18,7 +18,7 @@
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -26,35 +26,13 @@
 from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING
 from litdata.imports import RequirementCache
 
+if TYPE_CHECKING:
+    from PIL.JpegImagePlugin import JpegImageFile
+
 _PIL_AVAILABLE = RequirementCache("PIL")
 _TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
 _AV_AVAILABLE = RequirementCache("av")
 
-if _PIL_AVAILABLE:
-    from PIL import Image
-    from PIL.GifImagePlugin import GifImageFile
-    from PIL.JpegImagePlugin import JpegImageFile
-    from PIL.PngImagePlugin import PngImageFile
-    from PIL.WebPImagePlugin import WebPImageFile
-else:
-
-    class Image:  # type: ignore
-        Image = None
-
-    class JpegImageFile:  # type: ignore
-        pass
-
-    class PngImageFile:  # type: ignore
-        pass
-
-    class WebPImageFile:  # type: ignore
-        pass
-
-
-if _TORCH_VISION_AVAILABLE:
-    from torchvision.io import decode_jpeg
-    from torchvision.transforms.functional import pil_to_tensor
-
 
 class Serializer(ABC):
     """The base interface for any serializers.
@@ -91,6 +69,8 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]:
 
     @classmethod
     def deserialize(cls, data: bytes) -> Any:
+        from PIL import Image
+
         idx = 3 * 4
         width, height, mode_size = np.frombuffer(data[:idx], np.uint32)
         idx2 = idx + mode_size
@@ -100,6 +80,9 @@ def deserialize(cls, data: bytes) -> Any:
         return Image.frombytes(mode, size, raw)  # pyright: ignore
 
     def can_serialize(self, item: Any) -> bool:
+        from PIL import Image
+        from PIL.JpegImagePlugin import JpegImageFile
+
         return bool(_PIL_AVAILABLE) and isinstance(item, Image.Image) and not isinstance(item, JpegImageFile)
 
 
@@ -107,6 +90,12 @@ class JPEGSerializer(Serializer):
     """The JPEGSerializer serialize and deserialize JPEG image to and from bytes."""
 
     def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]:
+        from PIL import Image
+        from PIL.GifImagePlugin import GifImageFile
+        from PIL.JpegImagePlugin import JpegImageFile
+        from PIL.PngImagePlugin import PngImageFile
+        from PIL.WebPImagePlugin import WebPImageFile
+
         if isinstance(item, JpegImageFile):
             if not hasattr(item, "filename"):
                 raise ValueError(
@@ -130,8 +119,11 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]:
 
         raise TypeError(f"The provided item should be of type {JpegImageFile}. Found {item}.")
 
-    def deserialize(self, data: bytes) -> Union[JpegImageFile, torch.Tensor]:
+    def deserialize(self, data: bytes) -> Union["JpegImageFile", torch.Tensor]:
         if _TORCH_VISION_AVAILABLE:
+            from torchvision.io import decode_jpeg
+            from torchvision.transforms.functional import pil_to_tensor
+
             array = torch.frombuffer(data, dtype=torch.uint8)
             try:
                 return decode_jpeg(array)
@@ -145,6 +137,8 @@ def deserialize(self, data: bytes) -> Union[JpegImageFile, torch.Tensor]:
         return img
 
     def can_serialize(self, item: Any) -> bool:
+        from PIL.JpegImagePlugin import JpegImageFile
+
         return bool(_PIL_AVAILABLE) and isinstance(item, JpegImageFile)
 
 
diff --git a/src/litdata/utilities/format.py b/src/litdata/utilities/format.py
index 946b1063..ec9db747 100644
--- a/src/litdata/utilities/format.py
+++ b/src/litdata/utilities/format.py
@@ -10,6 +10,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any
+
+from litdata.constants import _TQDM_AVAILABLE
 
 _FORMAT_TO_RATIO = {
     "kb": 1000,
@@ -40,3 +43,15 @@ def _human_readable_bytes(num_bytes: float) -> str:
             return f"{num_bytes:3.1f} {unit}"
         num_bytes /= 1000.0
     return f"{num_bytes:.1f} PB"
+
+
+def _get_tqdm_iterator_if_available() -> Any:
+    if _TQDM_AVAILABLE:
+        from tqdm.auto import tqdm as _tqdm
+
+        return _tqdm
+
+    def _tqdm(iterator: Any) -> Any:
+        yield from iterator
+
+    return _tqdm
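
Note: every hunk above applies the same deferred-import idiom, so it is worth stating once outside the diff. Below is a minimal sketch, assuming `RequirementCache` from `litdata.imports` evaluates truthy when its package is importable and that `str(cache)` carries a human-readable "not installed" message (both consistent with the hunks above); the class body is trimmed from `litdata/streaming/compression.py` for illustration:

    from litdata.imports import RequirementCache

    # Probe availability once at import time; this creates no hard
    # dependency on the zstd package itself.
    _ZSTD_AVAILABLE = RequirementCache("zstd")


    class ZSTDCompressor:
        def __init__(self, level: int) -> None:
            # Fail fast at construction time with the cache's own message,
            # rather than failing at module-import time when zstd is absent.
            if not _ZSTD_AVAILABLE:
                raise ModuleNotFoundError(str(_ZSTD_AVAILABLE))
            self.level = level

        def compress(self, data: bytes) -> bytes:
            # Deferred import: zstd is only loaded when compression is used.
            import zstd

            return zstd.compress(data, self.level)

The tqdm helper follows the same principle with a fallback instead of an error: `_get_tqdm_iterator_if_available()` returns `tqdm.auto.tqdm` when installed and otherwise a pass-through generator with the same call shape, so call sites such as `for copy_info in _tqdm(copy_infos):` work unchanged either way.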