Refactor optional imports
awaelchli committed Jul 11, 2024
1 parent c4c9117 commit 897caab
Showing 11 changed files with 85 additions and 110 deletions.
2 changes: 1 addition & 1 deletion src/litdata/constants.py
@@ -33,7 +33,7 @@
_ZSTD_AVAILABLE = RequirementCache("zstd")
_GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage")
_TQDM_AVAILABLE = RequirementCache("tqdm")

+_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")

# DON'T CHANGE ORDER
_TORCH_DTYPES_MAPPING = {
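The availability flags above are lazy checks, not imports: constructing a RequirementCache does not import the package. A rough sketch of the behavior the rest of this commit relies on (assuming litdata's RequirementCache follows the usual lightning-utilities semantics: truthy when the requirement is installed, and str() yields an explanatory message):

from litdata.constants import _ZSTD_AVAILABLE


def compress(data: bytes, level: int = 4) -> bytes:
    # bool() triggers the (cached) availability check.
    if not _ZSTD_AVAILABLE:
        # str() renders a human-readable "package missing" message.
        raise ModuleNotFoundError(str(_ZSTD_AVAILABLE))
    import zstd  # deferred: only paid for when the feature is used

    return zstd.compress(data, level)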
18 changes: 6 additions & 12 deletions src/litdata/processing/data_processor.py
@@ -34,12 +34,10 @@
import torch

from litdata.constants import (
-    _BOTO3_AVAILABLE,
    _DEFAULT_FAST_DEV_RUN_ITEMS,
    _ENABLE_STATUS,
    _INDEX_FILENAME,
    _IS_IN_STUDIO,
-    _LIGHTNING_CLOUD_AVAILABLE,
    _TQDM_AVAILABLE,
)
from litdata.processing.readers import BaseReader, StreamingDataLoaderReader
@@ -53,16 +51,8 @@
from litdata.utilities.broadcast import broadcast_object
from litdata.utilities.dataset_utilities import load_index_file
from litdata.utilities.packing import _pack_greedily

-if _TQDM_AVAILABLE:
-    from tqdm.auto import tqdm as _tqdm
-
-if _LIGHTNING_CLOUD_AVAILABLE:
-    from lightning_cloud.openapi import V1DatasetType
-
-if _BOTO3_AVAILABLE:
-    import boto3
-    import botocore
+import boto3
+import botocore

logger = logging.Logger(__name__)

@@ -1043,6 +1033,8 @@ def run(self, data_recipe: DataRecipe) -> None:

        current_total = 0
        if _TQDM_AVAILABLE:
+            from tqdm.auto import tqdm as _tqdm
+
            pbar = _tqdm(
                desc="Progress",
                total=num_items,
@@ -1097,6 +1089,8 @@ def run(self, data_recipe: DataRecipe) -> None:
        result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)

        if num_nodes == node_rank + 1 and self.output_dir.url and self.output_dir.path is not None and _IS_IN_STUDIO:
+            from lightning_cloud.openapi import V1DatasetType
+
            _create_dataset(
                input_dir=self.input_dir.path,
                storage_dir=self.output_dir.path,
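Moving these imports from module scope into run() changes when the optional dependency is needed: importing litdata.processing.data_processor now succeeds without tqdm or lightning_cloud installed, and each package is imported only on the code path that uses it. A minimal sketch of the pattern (process() is a hypothetical stand-in for the call sites above):

from litdata.constants import _TQDM_AVAILABLE


def process(items: list) -> None:
    # Deferred import: the module imports fine without tqdm installed;
    # the dependency is only touched when this function actually runs.
    if _TQDM_AVAILABLE:
        from tqdm.auto import tqdm as _tqdm

        items = _tqdm(items, desc="Progress")
    for item in items:
        ...  # do the work per item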
10 changes: 3 additions & 7 deletions src/litdata/processing/functions.py
@@ -46,13 +46,7 @@
    _resolve_dir,
)
from litdata.utilities._pytree import tree_flatten
-
-if _TQDM_AVAILABLE:
-    from tqdm.auto import tqdm as _tqdm
-else:
-
-    def _tqdm(iterator: Any) -> Any:
-        yield from iterator
+from litdata.utilities.format import _get_tqdm_iterator_if_available


def _is_remote_file(path: str) -> bool:
@@ -563,6 +557,8 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None:

    index_json = {"config": input_dirs_file_content[0]["config"], "chunks": chunks}  # type: ignore

+    _tqdm = _get_tqdm_iterator_if_available()
+
    for copy_info in _tqdm(copy_infos):
        _apply_copy(copy_info, resolved_output_dir)

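_get_tqdm_iterator_if_available() centralizes the fallback that the deleted blocks used to duplicate in each module. Its body is not shown in this commit, but judging from the code it replaces, a sketch consistent with the call sites would be:

from typing import Any, Callable

from litdata.constants import _TQDM_AVAILABLE


def _get_tqdm_iterator_if_available() -> Callable:
    # Prefer the real progress bar when tqdm is installed.
    if _TQDM_AVAILABLE:
        from tqdm.auto import tqdm as _tqdm

        return _tqdm

    # Otherwise return a transparent pass-through with the same call shape.
    def _pass_through(iterator: Any) -> Any:
        yield from iterator

    return _pass_through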
11 changes: 3 additions & 8 deletions src/litdata/processing/readers.py
@@ -16,19 +16,12 @@
from abc import ABC, abstractmethod
from typing import Any, List

-from litdata.constants import _TQDM_AVAILABLE
from litdata.imports import RequirementCache
from litdata.streaming.dataloader import StreamingDataLoader
+from litdata.utilities.format import _get_tqdm_iterator_if_available

_PYARROW_AVAILABLE = RequirementCache("pyarrow")

-if _TQDM_AVAILABLE:
-    from tqdm.auto import tqdm as _tqdm
-else:
-
-    def _tqdm(iterator: Any) -> Any:
-        yield from iterator

class BaseReader(ABC):
def get_num_nodes(self) -> int:
@@ -92,6 +85,8 @@ def remap_items(self, filepaths: List[str], _: int) -> List[str]:
        cache_folder = os.path.join(self.cache_folder, f"{self.num_rows}")
        os.makedirs(cache_folder, exist_ok=True)

+        _tqdm = _get_tqdm_iterator_if_available()
+
        for filepath in filepaths:
            num_rows = self._get_num_rows(filepath)

22 changes: 7 additions & 15 deletions src/litdata/processing/utilities.py
@@ -21,23 +21,11 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from urllib import parse

-from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _LIGHTNING_CLOUD_AVAILABLE
+from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO
from litdata.streaming.cache import Dir

-if _LIGHTNING_CLOUD_AVAILABLE:
-    from lightning_cloud.openapi import (
-        ProjectIdDatasetsBody,
-    )
-    from lightning_cloud.openapi.rest import ApiException
-    from lightning_cloud.rest_client import LightningClient
-
-try:
-    import boto3
-    import botocore
-
-    _BOTO3_AVAILABLE = True
-except Exception:
-    _BOTO3_AVAILABLE = False
+import boto3
+import botocore


def _create_dataset(
@@ -67,6 +55,10 @@ def _create_dataset(
    if not storage_dir:
        raise ValueError("The storage_dir should be defined.")

+    from lightning_cloud.openapi import ProjectIdDatasetsBody
+    from lightning_cloud.openapi.rest import ApiException
+    from lightning_cloud.rest_client import LightningClient
+
    client = LightningClient(retry=False)

    try:
11 changes: 5 additions & 6 deletions src/litdata/streaming/client.py
@@ -15,13 +15,12 @@
from time import time
from typing import Any, Optional

-from litdata.constants import _BOTO3_AVAILABLE, _IS_IN_STUDIO
+from litdata.constants import _IS_IN_STUDIO

-if _BOTO3_AVAILABLE:
-    import boto3
-    import botocore
-    from botocore.credentials import InstanceMetadataProvider
-    from botocore.utils import InstanceMetadataFetcher
+import boto3
+import botocore
+from botocore.credentials import InstanceMetadataProvider
+from botocore.utils import InstanceMetadataFetcher

class S3Client:
9 changes: 5 additions & 4 deletions src/litdata/streaming/compression.py
@@ -18,9 +18,6 @@

TCompressor = TypeVar("TCompressor", bound="Compressor")

-if _ZSTD_AVAILABLE:
-    import zstd


class Compressor(ABC):
"""Base class for compression algorithm."""
@@ -45,7 +42,7 @@ class ZSTDCompressor(Compressor):
    def __init__(self, level: int) -> None:
        super().__init__()
        if not _ZSTD_AVAILABLE:
-            raise ModuleNotFoundError()
+            raise ModuleNotFoundError(str(_ZSTD_AVAILABLE))
        self.level = level
        self.extension = "zstd"

@@ -54,9 +51,13 @@ def name(self) -> str:
        return f"{self.extension}:{self.level}"

    def compress(self, data: bytes) -> bytes:
+        import zstd
+
        return zstd.compress(data, self.level)

    def decompress(self, data: bytes) -> bytes:
+        import zstd
+
        return zstd.decompress(data)

    @classmethod
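This is the commit's recurring shape for optional dependencies: fail fast in __init__ with the actionable message carried by the RequirementCache, then import the package only inside the methods that use it. A small usage sketch (compression level 4 is an arbitrary choice for illustration):

from litdata.streaming.compression import ZSTDCompressor

# Raises ModuleNotFoundError with an install hint if zstd is missing;
# otherwise the zstd import cost is paid on the first compress() call.
compressor = ZSTDCompressor(level=4)
payload = compressor.compress(b"some bytes")
assert compressor.decompress(payload) == b"some bytes"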
12 changes: 5 additions & 7 deletions src/litdata/streaming/downloader.py
@@ -23,13 +23,6 @@
from litdata.constants import _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME
from litdata.streaming.client import S3Client

-if _GOOGLE_STORAGE_AVAILABLE:
-    from google.cloud import storage
-else:
-
-    class storage:  # type: ignore
-        Client = None


class Downloader(ABC):
def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
@@ -96,9 +89,14 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:

class GCPDownloader(Downloader):
    def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
+        if not _GOOGLE_STORAGE_AVAILABLE:
+            raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE))
+
        super().__init__(remote_dir, cache_dir, chunks)

    def download_file(self, remote_filepath: str, local_filepath: str) -> None:
+        from google.cloud import storage
+
        obj = parse.urlparse(remote_filepath)

        if obj.scheme != "gs":
37 changes: 14 additions & 23 deletions src/litdata/streaming/resolver.py
@@ -20,33 +20,16 @@
from dataclasses import dataclass
from pathlib import Path
from time import sleep
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union, TYPE_CHECKING
from urllib import parse

-from litdata.constants import _LIGHTNING_CLOUD_AVAILABLE
+from litdata.constants import _LIGHTNING_SDK_AVAILABLE

-if _LIGHTNING_CLOUD_AVAILABLE:
-    from lightning_cloud.rest_client import LightningClient
-
-try:
-    import boto3
-    import botocore
-
-    _BOTO3_AVAILABLE = True
-except Exception:
-    _BOTO3_AVAILABLE = False
-
-
-try:
-    from lightning_sdk import Machine, Studio
-
-    _LIGHTNING_SDK_AVAILABLE = True
-except (ImportError, ModuleNotFoundError):
-
-    class Machine:  # type: ignore
-        pass
-
-    _LIGHTNING_SDK_AVAILABLE = False
+import boto3
+import botocore
+
+if TYPE_CHECKING:
+    from lightning_sdk import Machine


@dataclass
@@ -115,6 +98,8 @@ def _match_studio(target_id: Optional[str], target_name: Optional[str], cloudspa


def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Optional[str]) -> Dir:
+    from lightning_cloud.rest_client import LightningClient
+
    client = LightningClient(max_tries=2)

    # Get the ids from env variables
@@ -154,6 +139,8 @@ def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Option


def _resolve_s3_connections(dir_path: str) -> Dir:
+    from lightning_cloud.rest_client import LightningClient
+
    client = LightningClient(max_tries=2)

    # Get the ids from env variables
@@ -174,6 +161,8 @@


def _resolve_datasets(dir_path: str) -> Dir:
+    from lightning_cloud.rest_client import LightningClient
+
    client = LightningClient(max_tries=2)

    # Get the ids from env variables
@@ -362,6 +351,8 @@ def _execute(
    if not _LIGHTNING_SDK_AVAILABLE:
        raise ModuleNotFoundError("The `lightning_sdk` is required.")

+    from lightning_sdk import Studio
+
    lightning_skip_install = os.getenv("LIGHTNING_SKIP_INSTALL", "")
    if lightning_skip_install:
        lightning_skip_install = f" LIGHTNING_SKIP_INSTALL={lightning_skip_install} "
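The if TYPE_CHECKING: import at the top of this file keeps Machine visible to static type checkers without requiring lightning_sdk at runtime; annotations that reference it must then be quoted (or the module must use from __future__ import annotations). A minimal sketch, with a hypothetical start_machine() standing in for the real call sites:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated by type checkers only, never at runtime.
    from lightning_sdk import Machine


def start_machine(machine: Optional["Machine"] = None) -> None:
    # The quoted annotation avoids a runtime NameError when
    # lightning_sdk is not installed.
    ...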