Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor optional imports #221

Merged
Merged 7 commits on Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/litdata/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
_ZSTD_AVAILABLE = RequirementCache("zstd")
_GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage")
_TQDM_AVAILABLE = RequirementCache("tqdm")

_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")

# DON'T CHANGE ORDER
_TORCH_DTYPES_MAPPING = {
Expand Down
18 changes: 6 additions & 12 deletions src/litdata/processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
from urllib import parse

import boto3
import botocore
tchaton marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np
import torch

from litdata.constants import (
_BOTO3_AVAILABLE,
_DEFAULT_FAST_DEV_RUN_ITEMS,
_ENABLE_STATUS,
_INDEX_FILENAME,
_IS_IN_STUDIO,
_LIGHTNING_CLOUD_AVAILABLE,
_TQDM_AVAILABLE,
)
from litdata.processing.readers import BaseReader, StreamingDataLoaderReader
Expand All @@ -54,16 +54,6 @@
from litdata.utilities.dataset_utilities import load_index_file
from litdata.utilities.packing import _pack_greedily

if _TQDM_AVAILABLE:
from tqdm.auto import tqdm as _tqdm

if _LIGHTNING_CLOUD_AVAILABLE:
from lightning_cloud.openapi import V1DatasetType

if _BOTO3_AVAILABLE:
import boto3
import botocore

logger = logging.Logger(__name__)


Expand Down Expand Up @@ -1043,6 +1033,8 @@ def run(self, data_recipe: DataRecipe) -> None:

current_total = 0
if _TQDM_AVAILABLE:
from tqdm.auto import tqdm as _tqdm

pbar = _tqdm(
desc="Progress",
total=num_items,
Expand Down Expand Up @@ -1097,6 +1089,8 @@ def run(self, data_recipe: DataRecipe) -> None:
result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)

if num_nodes == node_rank + 1 and self.output_dir.url and self.output_dir.path is not None and _IS_IN_STUDIO:
from lightning_cloud.openapi import V1DatasetType

_create_dataset(
input_dir=self.input_dir.path,
storage_dir=self.output_dir.path,
Expand Down
12 changes: 4 additions & 8 deletions src/litdata/processing/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import torch

from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _TQDM_AVAILABLE
from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO
from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from litdata.processing.readers import BaseReader
from litdata.processing.utilities import (
Expand All @@ -46,13 +46,7 @@
_resolve_dir,
)
from litdata.utilities._pytree import tree_flatten

if _TQDM_AVAILABLE:
from tqdm.auto import tqdm as _tqdm
else:

def _tqdm(iterator: Any) -> Any:
yield from iterator
from litdata.utilities.format import _get_tqdm_iterator_if_available


def _is_remote_file(path: str) -> bool:
Expand Down Expand Up @@ -563,6 +557,8 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None:

index_json = {"config": input_dirs_file_content[0]["config"], "chunks": chunks} # type: ignore

_tqdm = _get_tqdm_iterator_if_available()

for copy_info in _tqdm(copy_infos):
_apply_copy(copy_info, resolved_output_dir)

Expand Down
11 changes: 3 additions & 8 deletions src/litdata/processing/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,12 @@
from abc import ABC, abstractmethod
from typing import Any, List

from litdata.constants import _TQDM_AVAILABLE
from litdata.imports import RequirementCache
from litdata.streaming.dataloader import StreamingDataLoader
from litdata.utilities.format import _get_tqdm_iterator_if_available

_PYARROW_AVAILABLE = RequirementCache("pyarrow")

if _TQDM_AVAILABLE:
from tqdm.auto import tqdm as _tqdm
else:

def _tqdm(iterator: Any) -> Any:
yield from iterator


class BaseReader(ABC):
def get_num_nodes(self) -> int:
Expand Down Expand Up @@ -92,6 +85,8 @@ def remap_items(self, filepaths: List[str], _: int) -> List[str]:
cache_folder = os.path.join(self.cache_folder, f"{self.num_rows}")
os.makedirs(cache_folder, exist_ok=True)

_tqdm = _get_tqdm_iterator_if_available()

for filepath in filepaths:
num_rows = self._get_num_rows(filepath)

Expand Down
24 changes: 8 additions & 16 deletions src/litdata/processing/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,11 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from urllib import parse

from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _LIGHTNING_CLOUD_AVAILABLE
from litdata.streaming.cache import Dir

if _LIGHTNING_CLOUD_AVAILABLE:
from lightning_cloud.openapi import (
ProjectIdDatasetsBody,
)
from lightning_cloud.openapi.rest import ApiException
from lightning_cloud.rest_client import LightningClient

try:
import boto3
import botocore
import boto3
import botocore

_BOTO3_AVAILABLE = True
except Exception:
_BOTO3_AVAILABLE = False
from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO
from litdata.streaming.cache import Dir


def _create_dataset(
Expand Down Expand Up @@ -67,6 +55,10 @@ def _create_dataset(
if not storage_dir:
raise ValueError("The storage_dir should be defined.")

from lightning_cloud.openapi import ProjectIdDatasetsBody
from lightning_cloud.openapi.rest import ApiException
from lightning_cloud.rest_client import LightningClient

client = LightningClient(retry=False)

try:
Expand Down
11 changes: 5 additions & 6 deletions src/litdata/streaming/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
from time import time
from typing import Any, Optional

from litdata.constants import _BOTO3_AVAILABLE, _IS_IN_STUDIO
import boto3
import botocore
from botocore.credentials import InstanceMetadataProvider
from botocore.utils import InstanceMetadataFetcher

if _BOTO3_AVAILABLE:
import boto3
import botocore
from botocore.credentials import InstanceMetadataProvider
from botocore.utils import InstanceMetadataFetcher
from litdata.constants import _IS_IN_STUDIO


class S3Client:
Expand Down
9 changes: 5 additions & 4 deletions src/litdata/streaming/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@

TCompressor = TypeVar("TCompressor", bound="Compressor")

if _ZSTD_AVAILABLE:
import zstd


class Compressor(ABC):
"""Base class for compression algorithm."""
Expand All @@ -45,7 +42,7 @@ class ZSTDCompressor(Compressor):
def __init__(self, level: int) -> None:
super().__init__()
if not _ZSTD_AVAILABLE:
raise ModuleNotFoundError()
raise ModuleNotFoundError(str(_ZSTD_AVAILABLE))
self.level = level
self.extension = "zstd"

Expand All @@ -54,9 +51,13 @@ def name(self) -> str:
return f"{self.extension}:{self.level}"

def compress(self, data: bytes) -> bytes:
import zstd

return zstd.compress(data, self.level)

def decompress(self, data: bytes) -> bytes:
import zstd

return zstd.decompress(data)

@classmethod
Expand Down
12 changes: 5 additions & 7 deletions src/litdata/streaming/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,6 @@
from litdata.constants import _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME
from litdata.streaming.client import S3Client

if _GOOGLE_STORAGE_AVAILABLE:
from google.cloud import storage
else:

class storage: # type: ignore
Client = None


class Downloader(ABC):
def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
Expand Down Expand Up @@ -96,9 +89,14 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:

class GCPDownloader(Downloader):
def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
if not _GOOGLE_STORAGE_AVAILABLE:
raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE))

super().__init__(remote_dir, cache_dir, chunks)

def download_file(self, remote_filepath: str, local_filepath: str) -> None:
from google.cloud import storage

obj = parse.urlparse(remote_filepath)

if obj.scheme != "gs":
Expand Down
39 changes: 15 additions & 24 deletions src/litdata/streaming/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,16 @@
from dataclasses import dataclass
from pathlib import Path
from time import sleep
from typing import Any, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
from urllib import parse

from litdata.constants import _LIGHTNING_CLOUD_AVAILABLE
import boto3
import botocore

if _LIGHTNING_CLOUD_AVAILABLE:
from lightning_cloud.rest_client import LightningClient

try:
import boto3
import botocore

_BOTO3_AVAILABLE = True
except Exception:
_BOTO3_AVAILABLE = False


try:
from lightning_sdk import Machine, Studio

_LIGHTNING_SDK_AVAILABLE = True
except (ImportError, ModuleNotFoundError):

class Machine: # type: ignore
pass
from litdata.constants import _LIGHTNING_SDK_AVAILABLE

_LIGHTNING_SDK_AVAILABLE = False
if TYPE_CHECKING:
from lightning_sdk import Machine


@dataclass
Expand Down Expand Up @@ -115,6 +98,8 @@ def _match_studio(target_id: Optional[str], target_name: Optional[str], cloudspa


def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Optional[str]) -> Dir:
from lightning_cloud.rest_client import LightningClient

client = LightningClient(max_tries=2)

# Get the ids from env variables
Expand Down Expand Up @@ -154,6 +139,8 @@ def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Option


def _resolve_s3_connections(dir_path: str) -> Dir:
from lightning_cloud.rest_client import LightningClient

client = LightningClient(max_tries=2)

# Get the ids from env variables
Expand All @@ -174,6 +161,8 @@ def _resolve_s3_connections(dir_path: str) -> Dir:


def _resolve_datasets(dir_path: str) -> Dir:
from lightning_cloud.rest_client import LightningClient

client = LightningClient(max_tries=2)

# Get the ids from env variables
Expand Down Expand Up @@ -354,14 +343,16 @@ def _resolve_time_template(path: str) -> str:
def _execute(
name: str,
num_nodes: int,
machine: Optional[Machine] = None,
machine: Optional["Machine"] = None,
command: Optional[str] = None,
) -> None:
"""Remotely execute the current operator."""

if not _LIGHTNING_SDK_AVAILABLE:
raise ModuleNotFoundError("The `lightning_sdk` is required.")

from lightning_sdk import Studio

lightning_skip_install = os.getenv("LIGHTNING_SKIP_INSTALL", "")
if lightning_skip_install:
lightning_skip_install = f" LIGHTNING_SKIP_INSTALL={lightning_skip_install} "
Expand Down
Loading
Loading