From df8dcd10bdae27a5b32a7b16a49989508d26c66a Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 12 Jul 2024 09:51:48 +0200 Subject: [PATCH] Refactor optional imports (#221) --- src/litdata/constants.py | 2 +- src/litdata/processing/data_processor.py | 18 +++------ src/litdata/processing/functions.py | 12 ++---- src/litdata/processing/readers.py | 11 ++---- src/litdata/processing/utilities.py | 24 ++++-------- src/litdata/streaming/client.py | 11 +++--- src/litdata/streaming/compression.py | 9 +++-- src/litdata/streaming/downloader.py | 12 +++--- src/litdata/streaming/resolver.py | 39 ++++++++----------- src/litdata/streaming/serializers.py | 48 +++++++++++------------- src/litdata/utilities/format.py | 15 ++++++++ tests/conftest.py | 35 +++++++++++++++++ tests/streaming/test_downloader.py | 8 ++-- tests/streaming/test_resolver.py | 32 ++++++++-------- 14 files changed, 142 insertions(+), 134 deletions(-) diff --git a/src/litdata/constants.py b/src/litdata/constants.py index 8befa208..6caa7520 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -33,7 +33,7 @@ _ZSTD_AVAILABLE = RequirementCache("zstd") _GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage") _TQDM_AVAILABLE = RequirementCache("tqdm") - +_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk") # DON'T CHANGE ORDER _TORCH_DTYPES_MAPPING = { diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 53cf3350..f7d2ce89 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -30,16 +30,16 @@ from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union from urllib import parse +import boto3 +import botocore import numpy as np import torch from litdata.constants import ( - _BOTO3_AVAILABLE, _DEFAULT_FAST_DEV_RUN_ITEMS, _ENABLE_STATUS, _INDEX_FILENAME, _IS_IN_STUDIO, - _LIGHTNING_CLOUD_AVAILABLE, _TQDM_AVAILABLE, ) from litdata.processing.readers import BaseReader, StreamingDataLoaderReader @@ -54,16 +54,6 @@ from litdata.utilities.dataset_utilities import load_index_file from litdata.utilities.packing import _pack_greedily -if _TQDM_AVAILABLE: - from tqdm.auto import tqdm as _tqdm - -if _LIGHTNING_CLOUD_AVAILABLE: - from lightning_cloud.openapi import V1DatasetType - -if _BOTO3_AVAILABLE: - import boto3 - import botocore - logger = logging.Logger(__name__) @@ -1043,6 +1033,8 @@ def run(self, data_recipe: DataRecipe) -> None: current_total = 0 if _TQDM_AVAILABLE: + from tqdm.auto import tqdm as _tqdm + pbar = _tqdm( desc="Progress", total=num_items, @@ -1097,6 +1089,8 @@ def run(self, data_recipe: DataRecipe) -> None: result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir) if num_nodes == node_rank + 1 and self.output_dir.url and self.output_dir.path is not None and _IS_IN_STUDIO: + from lightning_cloud.openapi import V1DatasetType + _create_dataset( input_dir=self.input_dir.path, storage_dir=self.output_dir.path, diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index adb5cdeb..2734d5f0 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -27,7 +27,7 @@ import torch -from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _TQDM_AVAILABLE +from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe from litdata.processing.readers import BaseReader from litdata.processing.utilities import ( @@ -46,13 +46,7 @@ _resolve_dir, ) from litdata.utilities._pytree import tree_flatten - -if _TQDM_AVAILABLE: - from tqdm.auto import tqdm as _tqdm -else: - - def _tqdm(iterator: Any) -> Any: - yield from iterator +from litdata.utilities.format import _get_tqdm_iterator_if_available def _is_remote_file(path: str) -> bool: @@ -563,6 +557,8 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None: index_json = {"config": input_dirs_file_content[0]["config"], "chunks": chunks} # type: ignore + _tqdm = _get_tqdm_iterator_if_available() + for copy_info in _tqdm(copy_infos): _apply_copy(copy_info, resolved_output_dir) diff --git a/src/litdata/processing/readers.py b/src/litdata/processing/readers.py index eb3c1511..f67d0f88 100644 --- a/src/litdata/processing/readers.py +++ b/src/litdata/processing/readers.py @@ -16,19 +16,12 @@ from abc import ABC, abstractmethod from typing import Any, List -from litdata.constants import _TQDM_AVAILABLE from litdata.imports import RequirementCache from litdata.streaming.dataloader import StreamingDataLoader +from litdata.utilities.format import _get_tqdm_iterator_if_available _PYARROW_AVAILABLE = RequirementCache("pyarrow") -if _TQDM_AVAILABLE: - from tqdm.auto import tqdm as _tqdm -else: - - def _tqdm(iterator: Any) -> Any: - yield from iterator - class BaseReader(ABC): def get_num_nodes(self) -> int: @@ -92,6 +85,8 @@ def remap_items(self, filepaths: List[str], _: int) -> List[str]: cache_folder = os.path.join(self.cache_folder, f"{self.num_rows}") os.makedirs(cache_folder, exist_ok=True) + _tqdm = _get_tqdm_iterator_if_available() + for filepath in filepaths: num_rows = self._get_num_rows(filepath) diff --git a/src/litdata/processing/utilities.py b/src/litdata/processing/utilities.py index 8312ea39..939c9e66 100644 --- a/src/litdata/processing/utilities.py +++ b/src/litdata/processing/utilities.py @@ -21,23 +21,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from urllib import parse -from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _LIGHTNING_CLOUD_AVAILABLE -from litdata.streaming.cache import Dir - -if _LIGHTNING_CLOUD_AVAILABLE: - from lightning_cloud.openapi import ( - ProjectIdDatasetsBody, - ) - from lightning_cloud.openapi.rest import ApiException - from lightning_cloud.rest_client import LightningClient - -try: - import boto3 - import botocore +import boto3 +import botocore - _BOTO3_AVAILABLE = True -except Exception: - _BOTO3_AVAILABLE = False +from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO +from litdata.streaming.cache import Dir def _create_dataset( @@ -67,6 +55,10 @@ def _create_dataset( if not storage_dir: raise ValueError("The storage_dir should be defined.") + from lightning_cloud.openapi import ProjectIdDatasetsBody + from lightning_cloud.openapi.rest import ApiException + from lightning_cloud.rest_client import LightningClient + client = LightningClient(retry=False) try: diff --git a/src/litdata/streaming/client.py b/src/litdata/streaming/client.py index 354e82a5..6185e189 100644 --- a/src/litdata/streaming/client.py +++ b/src/litdata/streaming/client.py @@ -15,13 +15,12 @@ from time import time from typing import Any, Optional -from litdata.constants import _BOTO3_AVAILABLE, _IS_IN_STUDIO +import boto3 +import botocore +from botocore.credentials import InstanceMetadataProvider +from botocore.utils import InstanceMetadataFetcher -if _BOTO3_AVAILABLE: - import boto3 - import botocore - from botocore.credentials import InstanceMetadataProvider - from botocore.utils import InstanceMetadataFetcher +from litdata.constants import _IS_IN_STUDIO class S3Client: diff --git a/src/litdata/streaming/compression.py b/src/litdata/streaming/compression.py index 21b30cf8..248f0781 100644 --- a/src/litdata/streaming/compression.py +++ b/src/litdata/streaming/compression.py @@ -18,9 +18,6 @@ TCompressor = TypeVar("TCompressor", bound="Compressor") -if _ZSTD_AVAILABLE: - import zstd - class Compressor(ABC): """Base class for compression algorithm.""" @@ -45,7 +42,7 @@ class ZSTDCompressor(Compressor): def __init__(self, level: int) -> None: super().__init__() if not _ZSTD_AVAILABLE: - raise ModuleNotFoundError() + raise ModuleNotFoundError(str(_ZSTD_AVAILABLE)) self.level = level self.extension = "zstd" @@ -54,9 +51,13 @@ def name(self) -> str: return f"{self.extension}:{self.level}" def compress(self, data: bytes) -> bytes: + import zstd + return zstd.compress(data, self.level) def decompress(self, data: bytes) -> bytes: + import zstd + return zstd.decompress(data) @classmethod diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index da5c7a9e..eba85c5e 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -23,13 +23,6 @@ from litdata.constants import _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME from litdata.streaming.client import S3Client -if _GOOGLE_STORAGE_AVAILABLE: - from google.cloud import storage -else: - - class storage: # type: ignore - Client = None - class Downloader(ABC): def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]): @@ -96,9 +89,14 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None: class GCPDownloader(Downloader): def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]): + if not _GOOGLE_STORAGE_AVAILABLE: + raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE)) + super().__init__(remote_dir, cache_dir, chunks) def download_file(self, remote_filepath: str, local_filepath: str) -> None: + from google.cloud import storage + obj = parse.urlparse(remote_filepath) if obj.scheme != "gs": diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index 633ebc05..a10ff52a 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -20,33 +20,16 @@ from dataclasses import dataclass from pathlib import Path from time import sleep -from typing import Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union from urllib import parse -from litdata.constants import _LIGHTNING_CLOUD_AVAILABLE +import boto3 +import botocore -if _LIGHTNING_CLOUD_AVAILABLE: - from lightning_cloud.rest_client import LightningClient - -try: - import boto3 - import botocore - - _BOTO3_AVAILABLE = True -except Exception: - _BOTO3_AVAILABLE = False - - -try: - from lightning_sdk import Machine, Studio - - _LIGHTNING_SDK_AVAILABLE = True -except (ImportError, ModuleNotFoundError): - - class Machine: # type: ignore - pass +from litdata.constants import _LIGHTNING_SDK_AVAILABLE - _LIGHTNING_SDK_AVAILABLE = False +if TYPE_CHECKING: + from lightning_sdk import Machine @dataclass @@ -115,6 +98,8 @@ def _match_studio(target_id: Optional[str], target_name: Optional[str], cloudspa def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Optional[str]) -> Dir: + from lightning_cloud.rest_client import LightningClient + client = LightningClient(max_tries=2) # Get the ids from env variables @@ -154,6 +139,8 @@ def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Option def _resolve_s3_connections(dir_path: str) -> Dir: + from lightning_cloud.rest_client import LightningClient + client = LightningClient(max_tries=2) # Get the ids from env variables @@ -174,6 +161,8 @@ def _resolve_s3_connections(dir_path: str) -> Dir: def _resolve_datasets(dir_path: str) -> Dir: + from lightning_cloud.rest_client import LightningClient + client = LightningClient(max_tries=2) # Get the ids from env variables @@ -354,7 +343,7 @@ def _resolve_time_template(path: str) -> str: def _execute( name: str, num_nodes: int, - machine: Optional[Machine] = None, + machine: Optional["Machine"] = None, command: Optional[str] = None, ) -> None: """Remotely execute the current operator.""" @@ -362,6 +351,8 @@ def _execute( if not _LIGHTNING_SDK_AVAILABLE: raise ModuleNotFoundError("The `lightning_sdk` is required.") + from lightning_sdk import Studio + lightning_skip_install = os.getenv("LIGHTNING_SKIP_INSTALL", "") if lightning_skip_install: lightning_skip_install = f" LIGHTNING_SKIP_INSTALL={lightning_skip_install} " diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index 87928fa5..58ab056a 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -18,7 +18,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from copy import deepcopy -from typing import Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import numpy as np import torch @@ -26,35 +26,13 @@ from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING from litdata.imports import RequirementCache +if TYPE_CHECKING: + from PIL.JpegImagePlugin import JpegImageFile + _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") _AV_AVAILABLE = RequirementCache("av") -if _PIL_AVAILABLE: - from PIL import Image - from PIL.GifImagePlugin import GifImageFile - from PIL.JpegImagePlugin import JpegImageFile - from PIL.PngImagePlugin import PngImageFile - from PIL.WebPImagePlugin import WebPImageFile -else: - - class Image: # type: ignore - Image = None - - class JpegImageFile: # type: ignore - pass - - class PngImageFile: # type: ignore - pass - - class WebPImageFile: # type: ignore - pass - - -if _TORCH_VISION_AVAILABLE: - from torchvision.io import decode_jpeg - from torchvision.transforms.functional import pil_to_tensor - class Serializer(ABC): """The base interface for any serializers. @@ -91,6 +69,8 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: @classmethod def deserialize(cls, data: bytes) -> Any: + from PIL import Image + idx = 3 * 4 width, height, mode_size = np.frombuffer(data[:idx], np.uint32) idx2 = idx + mode_size @@ -100,6 +80,9 @@ def deserialize(cls, data: bytes) -> Any: return Image.frombytes(mode, size, raw) # pyright: ignore def can_serialize(self, item: Any) -> bool: + from PIL import Image + from PIL.JpegImagePlugin import JpegImageFile + return bool(_PIL_AVAILABLE) and isinstance(item, Image.Image) and not isinstance(item, JpegImageFile) @@ -107,6 +90,12 @@ class JPEGSerializer(Serializer): """The JPEGSerializer serialize and deserialize JPEG image to and from bytes.""" def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: + from PIL import Image + from PIL.GifImagePlugin import GifImageFile + from PIL.JpegImagePlugin import JpegImageFile + from PIL.PngImagePlugin import PngImageFile + from PIL.WebPImagePlugin import WebPImageFile + if isinstance(item, JpegImageFile): if not hasattr(item, "filename"): raise ValueError( @@ -130,8 +119,11 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: raise TypeError(f"The provided item should be of type {JpegImageFile}. Found {item}.") - def deserialize(self, data: bytes) -> Union[JpegImageFile, torch.Tensor]: + def deserialize(self, data: bytes) -> Union["JpegImageFile", torch.Tensor]: if _TORCH_VISION_AVAILABLE: + from torchvision.io import decode_jpeg + from torchvision.transforms.functional import pil_to_tensor + array = torch.frombuffer(data, dtype=torch.uint8) try: return decode_jpeg(array) @@ -145,6 +137,8 @@ def deserialize(self, data: bytes) -> Union[JpegImageFile, torch.Tensor]: return img def can_serialize(self, item: Any) -> bool: + from PIL.JpegImagePlugin import JpegImageFile + return bool(_PIL_AVAILABLE) and isinstance(item, JpegImageFile) diff --git a/src/litdata/utilities/format.py b/src/litdata/utilities/format.py index 946b1063..969a8f46 100644 --- a/src/litdata/utilities/format.py +++ b/src/litdata/utilities/format.py @@ -10,6 +10,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + +from litdata.constants import _TQDM_AVAILABLE _FORMAT_TO_RATIO = { "kb": 1000, @@ -40,3 +43,15 @@ def _human_readable_bytes(num_bytes: float) -> str: return f"{num_bytes:3.1f} {unit}" num_bytes /= 1000.0 return f"{num_bytes:.1f} PB" + + +def _get_tqdm_iterator_if_available() -> Any: + if _TQDM_AVAILABLE: + from tqdm.auto import tqdm as _tqdm + + return _tqdm + + def _pass_through(iterator: Any) -> Any: + yield from iterator + + return _pass_through diff --git a/tests/conftest.py b/tests/conftest.py index 538d0bcb..17e47ed8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,7 @@ +import sys +from types import ModuleType +from unittest.mock import Mock + import pytest import torch.distributed @@ -30,3 +34,34 @@ def mosaic_mds_index_data(): ], "version": 2, } + + +@pytest.fixture() +def google_mock(monkeypatch): + google = ModuleType("google") + monkeypatch.setitem(sys.modules, "google", google) + google_cloud = ModuleType("cloud") + monkeypatch.setitem(sys.modules, "google.cloud", google_cloud) + google_cloud_storage = ModuleType("storage") + monkeypatch.setitem(sys.modules, "google.cloud.storage", google_cloud_storage) + google.cloud = google_cloud + google.cloud.storage = google_cloud_storage + return google + + +@pytest.fixture() +def lightning_cloud_mock(monkeypatch): + lightning_cloud = ModuleType("lightning_cloud") + monkeypatch.setitem(sys.modules, "lightning_cloud", lightning_cloud) + rest_client = ModuleType("rest_client") + monkeypatch.setitem(sys.modules, "lightning_cloud.rest_client", rest_client) + lightning_cloud.rest_client = rest_client + rest_client.LightningClient = Mock() + return lightning_cloud + + +@pytest.fixture() +def lightning_sdk_mock(monkeypatch): + lightning_sdk = ModuleType("lightning_sdk") + monkeypatch.setitem(sys.modules, "lightning_sdk", lightning_sdk) + return lightning_sdk diff --git a/tests/streaming/test_downloader.py b/tests/streaming/test_downloader.py index 9dfcce23..1c2d34a3 100644 --- a/tests/streaming/test_downloader.py +++ b/tests/streaming/test_downloader.py @@ -1,4 +1,5 @@ import os +from unittest import mock from unittest.mock import MagicMock from litdata.streaming.downloader import ( @@ -19,9 +20,8 @@ def test_s3_downloader_fast(tmpdir, monkeypatch): popen_mock.wait.assert_called() -def test_gcp_downloader(tmpdir, monkeypatch): - from litdata.streaming.downloader import storage - +@mock.patch("litdata.streaming.downloader._GOOGLE_STORAGE_AVAILABLE", True) +def test_gcp_downloader(tmpdir, monkeypatch, google_mock): # Create mock objects mock_client = MagicMock() mock_bucket = MagicMock() @@ -29,7 +29,7 @@ def test_gcp_downloader(tmpdir, monkeypatch): mock_blob.download_to_filename = MagicMock() # Patch the storage client to return the mock client - monkeypatch.setattr(storage, "Client", MagicMock(return_value=mock_client)) + google_mock.cloud.storage.Client = MagicMock(return_value=mock_client) # Configure the mock client to return the mock bucket and blob mock_client.bucket = MagicMock(return_value=mock_bucket) diff --git a/tests/streaming/test_resolver.py b/tests/streaming/test_resolver.py index 1d48d3d8..2c962454 100644 --- a/tests/streaming/test_resolver.py +++ b/tests/streaming/test_resolver.py @@ -19,7 +19,7 @@ @pytest.mark.skipif(sys.platform == "win32", reason="windows isn't supported") -def test_src_resolver_s3_connections(monkeypatch): +def test_src_resolver_s3_connections(monkeypatch, lightning_cloud_mock): auth = login.Auth() auth.save(user_id="7c8455e3-7c5f-4697-8a6d-105971d6b9bd", api_key="e63fae57-2b50-498b-bc46-d6204cbf330e") @@ -35,7 +35,7 @@ def test_src_resolver_s3_connections(monkeypatch): client_cls_mock = mock.MagicMock() client_cls_mock.return_value = client_mock - resolver.LightningClient = client_cls_mock + lightning_cloud_mock.rest_client.LightningClient = client_cls_mock assert resolver._resolve_dir("/teamspace/s3_connections/imagenet").url == "s3://imagenet-bucket" assert resolver._resolve_dir("/teamspace/s3_connections/imagenet/train").url == "s3://imagenet-bucket/train" @@ -47,7 +47,7 @@ def test_src_resolver_s3_connections(monkeypatch): client_cls_mock = mock.MagicMock() client_cls_mock.return_value = client_mock - resolver.LightningClient = client_cls_mock + lightning_cloud_mock.rest_client.LightningClient = client_cls_mock with pytest.raises(ValueError, match="name `imagenet`"): assert resolver._resolve_dir("/teamspace/s3_connections/imagenet") @@ -56,7 +56,7 @@ def test_src_resolver_s3_connections(monkeypatch): @pytest.mark.skipif(sys.platform == "win32", reason="windows isn't supported") -def test_src_resolver_studios(monkeypatch): +def test_src_resolver_studios(monkeypatch, lightning_cloud_mock): auth = login.Auth() auth.save(user_id="7c8455e3-7c5f-4697-8a6d-105971d6b9bd", api_key="e63fae57-2b50-498b-bc46-d6204cbf330e") @@ -85,7 +85,7 @@ def test_src_resolver_studios(monkeypatch): client_cls_mock = mock.MagicMock() client_cls_mock.return_value = client_mock - resolver.LightningClient = client_cls_mock + lightning_cloud_mock.rest_client.LightningClient = client_cls_mock expected = "s3://my_bucket/projects/project_id/cloudspaces/other_studio_id/code/content" assert resolver._resolve_dir("/teamspace/studios/other_studio").url == expected @@ -123,7 +123,7 @@ def fn(pattern): client_cls_mock = mock.MagicMock() client_cls_mock.return_value = client_mock - resolver.LightningClient = client_cls_mock + lightning_cloud_mock.rest_client.LightningClient = client_cls_mock with pytest.raises(ValueError, match="other_studio`"): resolver._resolve_dir("/teamspace/studios/other_studio") @@ -132,7 +132,7 @@ def fn(pattern): @pytest.mark.skipif(sys.platform == "win32", reason="windows isn't supported") -def test_src_resolver_datasets(monkeypatch): +def test_src_resolver_datasets(monkeypatch, lightning_cloud_mock): auth = login.Auth() auth.save(user_id="7c8455e3-7c5f-4697-8a6d-105971d6b9bd", api_key="e63fae57-2b50-498b-bc46-d6204cbf330e") @@ -168,7 +168,7 @@ def test_src_resolver_datasets(monkeypatch): client_cls_mock = mock.MagicMock() client_cls_mock.return_value = client_mock - resolver.LightningClient = client_cls_mock + lightning_cloud_mock.rest_client.LightningClient = client_cls_mock expected = "s3://my_bucket/projects/project_id/datasets/imagenet" assert resolver._resolve_dir("/teamspace/datasets/imagenet").url == expected @@ -185,7 +185,7 @@ def test_src_resolver_datasets(monkeypatch): client_cls_mock = mock.MagicMock() client_cls_mock.return_value = client_mock - resolver.LightningClient = client_cls_mock + lightning_cloud_mock.rest_client.LightningClient = client_cls_mock with pytest.raises(ValueError, match="cloud_space_id`"): resolver._resolve_dir("/teamspace/datasets/imagenet") @@ -194,7 +194,7 @@ def test_src_resolver_datasets(monkeypatch): @pytest.mark.skipif(sys.platform == "win32", reason="windows isn't supported") -def test_dst_resolver_dataset_path(monkeypatch): +def test_dst_resolver_dataset_path(monkeypatch, lightning_cloud_mock): auth = login.Auth() auth.save(user_id="7c8455e3-7c5f-4697-8a6d-105971d6b9bd", api_key="e63fae57-2b50-498b-bc46-d6204cbf330e") @@ -218,7 +218,7 @@ def test_dst_resolver_dataset_path(monkeypatch): client_cls_mock = mock.MagicMock() client_cls_mock.return_value = client_mock - resolver.LightningClient = client_cls_mock + lightning_cloud_mock.rest_client.LightningClient = client_cls_mock boto3 = mock.MagicMock() client_s3_mock = mock.MagicMock() @@ -240,7 +240,7 @@ def test_dst_resolver_dataset_path(monkeypatch): @pytest.mark.skipif(sys.platform == "win32", reason="windows isn't supported") @pytest.mark.parametrize("phase", ["LIGHTNINGAPP_INSTANCE_STATE_STOPPED", "LIGHTNINGAPP_INSTANCE_STATE_COMPLETED"]) -def test_execute(phase, monkeypatch): +def test_execute(phase, monkeypatch, lightning_sdk_mock): studio = mock.MagicMock() studio._studio.id = "studio_id" studio._teamspace.id = "teamspace_id" @@ -256,11 +256,9 @@ def test_execute(phase, monkeypatch): job.status.phase = phase studio._studio_api.create_data_prep_machine_job.return_value = job studio._studio_api._client.lightningapp_instance_service_get_lightningapp_instance.return_value = job - if not hasattr(resolver, "Studio"): - resolver.Studio = mock.MagicMock(return_value=studio) - resolver._LIGHTNING_SDK_AVAILABLE = True - else: - monkeypatch.setattr(resolver, "Studio", mock.MagicMock(return_value=studio)) + + monkeypatch.setattr(resolver, "_LIGHTNING_SDK_AVAILABLE", True) + lightning_sdk_mock.Studio = mock.MagicMock(return_value=studio) called = False