From d7682b93213077e3d5c45ac03c0d06e2bf85f39b Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 19 Sep 2024 17:49:36 +0200 Subject: [PATCH] Revert "Feat: Using fsspec to download files (#348)" This reverts commit 719bae27b6c7e12c19f1c12a121ce5a1e95bf51a. --- README.md | 31 +-- requirements.txt | 4 +- requirements/extras.txt | 2 - src/litdata/constants.py | 1 - src/litdata/processing/data_processor.py | 172 ++++++------- src/litdata/processing/functions.py | 74 +++--- src/litdata/processing/utilities.py | 43 +++- src/litdata/streaming/client.py | 70 ++++++ src/litdata/streaming/dataset.py | 6 +- src/litdata/streaming/downloader.py | 296 ++++++++++------------- src/litdata/streaming/resolver.py | 89 ++++--- tests/processing/test_data_processor.py | 33 ++- tests/streaming/test_client.py | 97 ++++++++ tests/streaming/test_dataset.py | 2 +- tests/streaming/test_downloader.py | 67 ++++- tests/streaming/test_resolver.py | 65 +++-- 16 files changed, 611 insertions(+), 441 deletions(-) create mode 100644 src/litdata/streaming/client.py create mode 100644 tests/streaming/test_client.py diff --git a/README.md b/README.md index 78ab67c6..e03d0c13 100644 --- a/README.md +++ b/README.md @@ -217,8 +217,9 @@ Additionally, you can inject client connection settings for [S3](https://boto3.a from litdata import StreamingDataset storage_options = { - "key": "your_access_key_id", - "secret": "your_secret_access_key", + "endpoint_url": "your_endpoint_url", + "aws_access_key_id": "your_access_key_id", + "aws_secret_access_key": "your_secret_access_key", } dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options) @@ -263,7 +264,7 @@ for batch in val_dataloader:   -The StreamingDataset supports reading optimized datasets from common cloud providers. +The StreamingDataset supports reading optimized datasets from common cloud providers. ```python import os @@ -271,39 +272,25 @@ import litdata as ld # Read data from AWS S3 aws_storage_options={ - "key": os.environ['AWS_ACCESS_KEY_ID'], - "secret": os.environ['AWS_SECRET_ACCESS_KEY'], + "AWS_ACCESS_KEY_ID": os.environ['AWS_ACCESS_KEY_ID'], + "AWS_SECRET_ACCESS_KEY": os.environ['AWS_SECRET_ACCESS_KEY'], } dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options) # Read data from GCS gcp_storage_options={ - "token": { - # dumped from cat ~/.config/gcloud/application_default_credentials.json - "account": "", - "client_id": "your_client_id", - "client_secret": "your_client_secret", - "quota_project_id": "your_quota_project_id", - "refresh_token": "your_refresh_token", - "type": "authorized_user", - "universe_domain": "googleapis.com", - } + "project": os.environ['PROJECT_ID'], } dataset = ld.StreamingDataset("gs://my-bucket/my-data", storage_options=gcp_storage_options) # Read data from Azure azure_storage_options={ - "account_name": "azure_account_name", - "account_key": os.environ['AZURE_ACCOUNT_ACCESS_KEY'] + "account_url": f"https://{os.environ['AZURE_ACCOUNT_NAME']}.blob.core.windows.net", + "credential": os.environ['AZURE_ACCOUNT_ACCESS_KEY'] } dataset = ld.StreamingDataset("azure://my-bucket/my-data", storage_options=azure_storage_options) ``` -- For more details on which storage options are supported, please refer to: - - [AWS S3 storage options](https://github.com/fsspec/s3fs/blob/main/s3fs/core.py#L176) - - [GCS storage options](https://github.com/fsspec/gcsfs/blob/main/gcsfs/core.py#L154) - - [Azure storage options](https://github.com/fsspec/adlfs/blob/main/adlfs/spec.py#L124) -
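[Editorial aside on the README change above — not part of the patch. After this revert, the `storage_options` keys shown in the README are forwarded directly as keyword arguments to `boto3.client("s3", ...)` by the `S3Client` class restored later in this patch, instead of being handed to `fsspec`. A minimal sketch of that mapping follows; the endpoint URL and bucket name are placeholders, not values taken from this patch.]

```python
import os

import boto3
import botocore

from litdata import StreamingDataset

# boto3-style keys, matching the updated README example
storage_options = {
    "endpoint_url": "https://s3.us-west-2.amazonaws.com",  # placeholder endpoint
    "aws_access_key_id": os.environ["AWS_ACCESS_KEY_ID"],
    "aws_secret_access_key": os.environ["AWS_SECRET_ACCESS_KEY"],
}

# Roughly what the restored S3Client does internally: the retry config is set,
# then every user-supplied key is passed straight through to boto3.client("s3", ...).
s3 = boto3.client(
    "s3",
    **{
        "config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}),
        **storage_options,
    },
)

# The same dict is what StreamingDataset expects after the revert.
dataset = StreamingDataset("s3://my-bucket/my-data", storage_options=storage_options)
```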
diff --git a/requirements.txt b/requirements.txt index ec443722..06a629a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,5 @@ torch lightning-utilities filelock numpy -# boto3 +boto3 requests -fsspec -fsspec[s3] # aws s3 diff --git a/requirements/extras.txt b/requirements/extras.txt index 33d42446..385e2e81 100644 --- a/requirements/extras.txt +++ b/requirements/extras.txt @@ -5,5 +5,3 @@ pyarrow tqdm lightning-sdk ==0.1.17 # Must be pinned to ensure compatibility google-cloud-storage -fsspec[gs] # google cloud storage -fsspec[abfs] # azure blob diff --git a/src/litdata/constants.py b/src/litdata/constants.py index efe2e248..a6a714c7 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -85,4 +85,3 @@ _TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ" _IS_IN_STUDIO = bool(os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)) and bool(os.getenv("LIGHTNING_CLUSTER_ID", None)) _ENABLE_STATUS = bool(int(os.getenv("ENABLE_STATUS_REPORT", "0"))) -_SUPPORTED_CLOUD_PROVIDERS = ["s3", "gs", "azure", "abfs"] diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index fae806b5..f1af9afa 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -32,6 +32,8 @@ from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union from urllib import parse +import boto3 +import botocore import numpy as np import torch @@ -40,21 +42,14 @@ _ENABLE_STATUS, _INDEX_FILENAME, _IS_IN_STUDIO, - _SUPPORTED_CLOUD_PROVIDERS, _TQDM_AVAILABLE, ) from litdata.processing.readers import BaseReader, StreamingDataLoaderReader -from litdata.processing.utilities import _create_dataset, remove_uuid_from_filename +from litdata.processing.utilities import _create_dataset, download_directory_from_S3, remove_uuid_from_filename from litdata.streaming import Cache from litdata.streaming.cache import Dir +from litdata.streaming.client import S3Client from litdata.streaming.dataloader import StreamingDataLoader -from litdata.streaming.downloader import ( - does_file_exist, - download_file_or_directory, - get_cloud_provider, - remove_file_or_directory, - upload_file_or_directory, -) from litdata.streaming.item_loader import BaseItemLoader from litdata.streaming.resolver import _resolve_dir from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads @@ -101,22 +96,14 @@ def _get_cache_data_dir(name: Optional[str] = None) -> str: return os.path.join(cache_dir, name.lstrip("/")) -def _wait_for_file_to_exist( - remote_filepath: str, sleep_time: int = 2, wait_for_count: int = 5, storage_options: Optional[Dict] = {} -) -> Any: - """This function check if a file exists on the remote storage. - - If not, it waits for a while and tries again. 
- - """ - cloud_provider = get_cloud_provider(remote_filepath) +def _wait_for_file_to_exist(s3: S3Client, obj: parse.ParseResult, sleep_time: int = 2) -> Any: + """This function check.""" while True: try: - return does_file_exist(remote_filepath, cloud_provider, storage_options=storage_options) - except Exception as e: - if wait_for_count > 0: + return s3.client.head_object(Bucket=obj.netloc, Key=obj.path.lstrip("/")) + except botocore.exceptions.ClientError as e: + if "the HeadObject operation: Not Found" in str(e): sleep(sleep_time) - wait_for_count -= 1 else: raise e @@ -131,10 +118,10 @@ def _wait_for_disk_usage_higher_than_threshold(input_dir: str, threshold_in_gb: return -def _download_data_target( - input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue, storage_options: Optional[Dict] = {} -) -> None: - """This function is used to download data from a remote directory to a cache directory to optimise reading.""" +def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue) -> None: + """Download data from a remote directory to a cache directory to optimise reading.""" + s3 = S3Client() + while True: # 2. Fetch from the queue r: Optional[Tuple[int, List[str]]] = queue_in.get() @@ -169,11 +156,13 @@ def _download_data_target( obj = parse.urlparse(path) - if obj.scheme in _SUPPORTED_CLOUD_PROVIDERS: + if obj.scheme == "s3": dirpath = os.path.dirname(local_path) os.makedirs(dirpath, exist_ok=True) - download_file_or_directory(path, local_path, storage_options=storage_options) + + with open(local_path, "wb") as f: + s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f) elif os.path.isfile(path): if not path.startswith("/teamspace/studios/this_studio"): @@ -209,13 +198,12 @@ def _remove_target(input_dir: Dir, cache_dir: str, queue_in: Queue) -> None: os.remove(path) -def _upload_fn( - upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir, storage_options: Optional[Dict] = {} -) -> None: - """This function is used to upload optimised chunks from a local to remote dataset directory.""" +def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir) -> None: + """Upload optimised chunks from a local to remote dataset directory.""" obj = parse.urlparse(output_dir.url if output_dir.url else output_dir.path) - is_remote = obj.scheme in _SUPPORTED_CLOUD_PROVIDERS + if obj.scheme == "s3": + s3 = S3Client() while True: data: Optional[Union[str, Tuple[str, str]]] = upload_queue.get() @@ -235,7 +223,7 @@ def _upload_fn( if not local_filepath.startswith(cache_dir): local_filepath = os.path.join(cache_dir, local_filepath) - if is_remote: + if obj.scheme == "s3": try: output_filepath = str(obj.path).lstrip("/") @@ -247,8 +235,12 @@ def _upload_fn( output_filepath = os.path.join(output_filepath, local_filepath.replace(tmpdir, "")[1:]) output_filepath = remove_uuid_from_filename(output_filepath) # remove unique id from checkpoints - remote_filepath = str(obj.scheme) + "://" + str(obj.netloc) + "/" + output_filepath - upload_file_or_directory(local_filepath, remote_filepath, storage_options=storage_options) + + s3.client.upload_file( + local_filepath, + obj.netloc, + output_filepath, + ) except Exception as e: print(e) @@ -425,7 +417,6 @@ def __init__( checkpoint_chunks_info: Optional[List[Dict[str, Any]]] = None, checkpoint_next_index: Optional[int] = None, item_loader: Optional[BaseItemLoader] = None, - storage_options: Optional[Dict] = {}, ) -> None: """The BaseWorker is responsible to process the user 
data.""" self.worker_index = worker_index @@ -460,7 +451,6 @@ def __init__( self.use_checkpoint: bool = use_checkpoint self.checkpoint_chunks_info: Optional[List[Dict[str, Any]]] = checkpoint_chunks_info self.checkpoint_next_index: Optional[int] = checkpoint_next_index - self.storage_options = storage_options def run(self) -> None: try: @@ -637,7 +627,6 @@ def _start_downloaders(self) -> None: self.cache_data_dir, to_download_queue, self.ready_to_process_queue, - self.storage_options, ), ) p.start() @@ -677,7 +666,6 @@ def _start_uploaders(self) -> None: self.remove_queue, self.cache_chunks_dir, self.output_dir, - self.storage_options, ), ) p.start() @@ -779,7 +767,6 @@ def __init__( chunk_bytes: Optional[Union[int, str]] = None, compression: Optional[str] = None, encryption: Optional[Encryption] = None, - storage_options: Optional[Dict] = {}, ): super().__init__() if chunk_size is not None and chunk_bytes is not None: @@ -789,7 +776,6 @@ def __init__( self.chunk_bytes = 1 << 26 if chunk_size is None and chunk_bytes is None else chunk_bytes self.compression = compression self.encryption = encryption - self.storage_options = storage_options @abstractmethod def prepare_structure(self, input_dir: Optional[str]) -> List[T]: @@ -856,12 +842,10 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra else: local_filepath = os.path.join(cache_dir, _INDEX_FILENAME) - if obj.scheme in _SUPPORTED_CLOUD_PROVIDERS: - remote_filepath = str(obj.scheme) + "://" + str(obj.netloc) + "/" - upload_file_or_directory( - local_filepath, - remote_filepath + os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath)), - storage_options=self.storage_options, + if obj.scheme == "s3": + s3 = S3Client() + s3.client.upload_file( + local_filepath, obj.netloc, os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath)) ) elif output_dir.path and os.path.isdir(output_dir.path): shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath))) @@ -879,13 +863,11 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra assert output_dir_path remote_filepath = os.path.join(output_dir_path, f"{node_rank}-{_INDEX_FILENAME}") node_index_filepath = os.path.join(cache_dir, os.path.basename(remote_filepath)) - if obj.scheme in _SUPPORTED_CLOUD_PROVIDERS: - _wait_for_file_to_exist(remote_filepath, storage_options=self.storage_options) - download_file_or_directory( - remote_filepath, - node_index_filepath, - storage_options=self.storage_options, - ) + if obj.scheme == "s3": + obj = parse.urlparse(remote_filepath) + _wait_for_file_to_exist(s3, obj) + with open(node_index_filepath, "wb") as f: + s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f) elif output_dir.path and os.path.isdir(output_dir.path): shutil.copyfile(remote_filepath, node_index_filepath) @@ -926,7 +908,6 @@ def __init__( use_checkpoint: bool = False, item_loader: Optional[BaseItemLoader] = None, start_method: Optional[str] = None, - storage_options: Optional[Dict] = {}, ): """Provides an efficient way to process data across multiple machine into chunks to make training faster. @@ -951,7 +932,6 @@ def __init__( the format in which the data is stored and optimized for loading. start_method: The start method used by python multiprocessing package. Default to spawn unless running inside an interactive shell like Ipython. - storage_options: The storage options used by the cloud provider. 
""" # spawn doesn't work in IPython @@ -988,7 +968,6 @@ def __init__( self.item_loader = item_loader self.state_dict = state_dict or {rank: 0 for rank in range(self.num_workers)} - self.storage_options = storage_options if self.reader is not None and self.weights is not None: raise ValueError("Either the reader or the weights needs to be defined.") @@ -1140,11 +1119,7 @@ def run(self, data_recipe: DataRecipe) -> None: # Exit early if all the workers are done. # This means there were some kinda of errors. if all(not w.is_alive() for w in self.workers): - try: - error = self.error_queue.get(timeout=0.001) - self._exit_on_error(error) - except Empty: - break + raise RuntimeError("One of the worker has failed") if _TQDM_AVAILABLE: pbar.close() @@ -1211,7 +1186,6 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L self.checkpoint_chunks_info[worker_idx] if self.checkpoint_chunks_info else None, self.checkpoint_next_index[worker_idx] if self.checkpoint_next_index else None, self.item_loader, - storage_options=self.storage_options, ) worker.start() workers.append(worker) @@ -1263,14 +1237,21 @@ def _cleanup_checkpoints(self) -> None: obj = parse.urlparse(self.output_dir.url) - if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: - raise ValueError( - f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {self.output_dir.path}." - ) - with suppress(FileNotFoundError): - remove_file_or_directory( - os.path.join(self.output_dir.url, ".checkpoints"), storage_options=self.storage_options - ) + if obj.scheme != "s3": + raise ValueError(f"The provided folder should start with s3://. Found {self.output_dir.path}.") + + s3 = boto3.client("s3") + + prefix = obj.path.lstrip("/").rstrip("/") + "/" + + # Delete all the files (including the index file in overwrite mode) + bucket_name = obj.netloc + s3 = boto3.resource("s3") + + checkpoint_prefix = os.path.join(prefix, ".checkpoints") + + for obj in s3.Bucket(bucket_name).objects.filter(Prefix=checkpoint_prefix): + s3.Object(bucket_name, obj.key).delete() def _save_current_config(self, workers_user_items: List[List[Any]]) -> None: if not self.use_checkpoint: @@ -1296,20 +1277,24 @@ def _save_current_config(self, workers_user_items: List[List[Any]]) -> None: obj = parse.urlparse(self.output_dir.url) - if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: - raise ValueError( - f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {self.output_dir.path}." - ) + if obj.scheme != "s3": + raise ValueError(f"The provided folder should start with s3://. Found {self.output_dir.path}.") + + # TODO: Add support for all cloud providers + + s3 = S3Client() + + prefix = obj.path.lstrip("/").rstrip("/") + "/" + ".checkpoints/" # write config.json file to temp directory and upload it to s3 with tempfile.TemporaryDirectory() as temp_dir: temp_file_name = os.path.join(temp_dir, "config.json") with open(temp_file_name, "w") as f: json.dump(config, f) - upload_file_or_directory( + s3.client.upload_file( temp_file_name, - os.path.join(self.output_dir.url, ".checkpoints", "config.json"), - storage_options=self.storage_options, + obj.netloc, + os.path.join(prefix, "config.json"), ) except Exception as e: print(e) @@ -1360,25 +1345,26 @@ def _load_checkpoint_config(self, workers_user_items: List[List[Any]]) -> None: obj = parse.urlparse(self.output_dir.url) - if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: - raise ValueError( - f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. 
Found {self.output_dir.path}." - ) + if obj.scheme != "s3": + raise ValueError(f"The provided folder should start with s3://. Found {self.output_dir.path}.") + + # TODO: Add support for all cloud providers + + prefix = obj.path.lstrip("/").rstrip("/") + "/" + ".checkpoints/" + + # Delete all the files (including the index file in overwrite mode) + bucket_name = obj.netloc # download all the checkpoint files in tempdir and read them with tempfile.TemporaryDirectory() as temp_dir: - try: - download_file_or_directory( - os.path.join(self.output_dir.url, ".checkpoints/"), temp_dir, storage_options=self.storage_options - ) - except FileNotFoundError: - return - if not os.path.exists(os.path.join(temp_dir, "config.json")): + saved_file_dir = download_directory_from_S3(bucket_name, prefix, temp_dir) + + if not os.path.exists(os.path.join(saved_file_dir, "config.json")): # if the config.json file doesn't exist, we don't have any checkpoint saved return # read the config.json file - with open(os.path.join(temp_dir, "config.json")) as f: + with open(os.path.join(saved_file_dir, "config.json")) as f: config = json.load(f) if config["num_workers"] != self.num_workers: @@ -1392,11 +1378,11 @@ def _load_checkpoint_config(self, workers_user_items: List[List[Any]]) -> None: checkpoint_file_names = [f"checkpoint-{worker_idx}.json" for worker_idx in range(self.num_workers)] for i, checkpoint_file_name in enumerate(checkpoint_file_names): - if not os.path.exists(os.path.join(temp_dir, checkpoint_file_name)): + if not os.path.exists(os.path.join(saved_file_dir, checkpoint_file_name)): # if the checkpoint file doesn't exist, we don't have any checkpoint saved for this worker continue - with open(os.path.join(temp_dir, checkpoint_file_name)) as f: + with open(os.path.join(saved_file_dir, checkpoint_file_name)) as f: checkpoint = json.load(f) self.checkpoint_chunks_info[i] = checkpoint["chunks"] diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index e83c8c1b..dd62909a 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -27,7 +27,7 @@ import torch -from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _SUPPORTED_CLOUD_PROVIDERS +from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe from litdata.processing.readers import BaseReader from litdata.processing.utilities import ( @@ -36,8 +36,8 @@ optimize_dns_context, read_index_file_content, ) +from litdata.streaming.client import S3Client from litdata.streaming.dataloader import StreamingDataLoader -from litdata.streaming.downloader import copy_file_or_directory, upload_file_or_directory from litdata.streaming.item_loader import BaseItemLoader from litdata.streaming.resolver import ( Dir, @@ -53,7 +53,7 @@ def _is_remote_file(path: str) -> bool: obj = parse.urlparse(path) - return obj.scheme in _SUPPORTED_CLOUD_PROVIDERS + return obj.scheme in ["s3", "gcs"] def _get_indexed_paths(data: Any) -> Dict[int, str]: @@ -151,15 +151,8 @@ def __init__( compression: Optional[str], encryption: Optional[Encryption] = None, existing_index: Optional[Dict[str, Any]] = None, - storage_options: Optional[Dict] = {}, ): - super().__init__( - chunk_size=chunk_size, - chunk_bytes=chunk_bytes, - compression=compression, - encryption=encryption, - storage_options=storage_options, - ) + super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression, encryption=encryption) 
self._fn = fn self._inputs = inputs self.is_generator = False @@ -206,7 +199,6 @@ def map( error_when_not_empty: bool = False, reader: Optional[BaseReader] = None, batch_size: Optional[int] = None, - storage_options: Optional[Dict] = {}, ) -> None: """Maps a callable over a collection of inputs, possibly in a distributed way. @@ -228,7 +220,6 @@ def map( error_when_not_empty: Whether we should error if the output folder isn't empty. reader: The reader to use when reading the data. By default, it uses the `BaseReader`. batch_size: Group the inputs into batches of batch_size length. - storage_options: The storage options used by the cloud provider. """ if isinstance(inputs, StreamingDataLoader) and batch_size is not None: @@ -267,7 +258,7 @@ def map( ) if error_when_not_empty: - _assert_dir_is_empty(_output_dir, storage_options=storage_options) + _assert_dir_is_empty(_output_dir) if not isinstance(inputs, StreamingDataLoader): input_dir = input_dir or _get_input_dir(inputs) @@ -291,7 +282,6 @@ def map( reorder_files=reorder_files, weights=weights, reader=reader, - storage_options=storage_options, ) with optimize_dns_context(True): return data_processor.run(LambdaDataTransformRecipe(fn, inputs)) @@ -325,7 +315,6 @@ def optimize( use_checkpoint: bool = False, item_loader: Optional[BaseItemLoader] = None, start_method: Optional[str] = None, - storage_options: Optional[Dict] = {}, ) -> None: """This function converts a dataset into chunks, possibly in a distributed way. @@ -360,7 +349,6 @@ def optimize( the format in which the data is stored and optimized for loading. start_method: The start method used by python multiprocessing package. Default to spawn unless running inside an interactive shell like Ipython. - storage_options: The storage options used by the cloud provider. """ if mode is not None and mode not in ["append", "overwrite"]: @@ -415,9 +403,7 @@ def optimize( "\n HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`." ) - _assert_dir_has_index_file( - _output_dir, mode=mode, use_checkpoint=use_checkpoint, storage_options=storage_options - ) + _assert_dir_has_index_file(_output_dir, mode=mode, use_checkpoint=use_checkpoint) if not isinstance(inputs, StreamingDataLoader): resolved_dir = _resolve_dir(input_dir or _get_input_dir(inputs)) @@ -433,9 +419,7 @@ def optimize( num_workers = num_workers or _get_default_num_workers() state_dict = {rank: 0 for rank in range(num_workers)} - existing_index_file_content = ( - read_index_file_content(_output_dir, storage_options=storage_options) if mode == "append" else None - ) + existing_index_file_content = read_index_file_content(_output_dir) if mode == "append" else None if existing_index_file_content is not None: for chunk in existing_index_file_content["chunks"]: @@ -457,7 +441,6 @@ def optimize( use_checkpoint=use_checkpoint, item_loader=item_loader, start_method=start_method, - storage_options=storage_options, ) with optimize_dns_context(True): @@ -470,7 +453,6 @@ def optimize( compression=compression, encryption=encryption, existing_index=existing_index_file_content, - storage_options=storage_options, ) ) return None @@ -539,14 +521,12 @@ class CopyInfo: new_filename: str -def merge_datasets(input_dirs: List[str], output_dir: str, storage_options: Optional[Dict] = {}) -> None: - """The merge_datasets utility enables to merge multiple existing optimized datasets into a single optimized - dataset. 
+def merge_datasets(input_dirs: List[str], output_dir: str) -> None: + """Enables to merge multiple existing optimized datasets into a single optimized dataset. Args: input_dirs: A list of directories pointing to the existing optimized datasets. output_dir: The directory where the merged dataset would be stored. - storage_options: A dictionary of storage options to be passed to the fsspec library. """ if len(input_dirs) == 0: @@ -561,14 +541,12 @@ def merge_datasets(input_dirs: List[str], output_dir: str, storage_options: Opti if any(input_dir == resolved_output_dir for input_dir in resolved_input_dirs): raise ValueError("The provided output_dir was found within the input_dirs. This isn't supported.") - input_dirs_file_content = [ - read_index_file_content(input_dir, storage_options=storage_options) for input_dir in resolved_input_dirs - ] + input_dirs_file_content = [read_index_file_content(input_dir) for input_dir in resolved_input_dirs] if any(file_content is None for file_content in input_dirs_file_content): raise ValueError("One of the provided input_dir doesn't have an index file.") - output_dir_file_content = read_index_file_content(resolved_output_dir, storage_options=storage_options) + output_dir_file_content = read_index_file_content(resolved_output_dir) if output_dir_file_content is not None: raise ValueError("The output_dir already contains an optimized dataset") @@ -603,12 +581,12 @@ def merge_datasets(input_dirs: List[str], output_dir: str, storage_options: Opti _tqdm = _get_tqdm_iterator_if_available() for copy_info in _tqdm(copy_infos): - _apply_copy(copy_info, resolved_output_dir, storage_options=storage_options) + _apply_copy(copy_info, resolved_output_dir) - _save_index(index_json, resolved_output_dir, storage_options=storage_options) + _save_index(index_json, resolved_output_dir) -def _apply_copy(copy_info: CopyInfo, output_dir: Dir, storage_options: Optional[Dict] = {}) -> None: +def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None: if output_dir.url is None and copy_info.input_dir.url is None: assert copy_info.input_dir.path assert output_dir.path @@ -618,15 +596,20 @@ def _apply_copy(copy_info: CopyInfo, output_dir: Dir, storage_options: Optional[ shutil.copyfile(input_filepath, output_filepath) elif output_dir.url and copy_info.input_dir.url: - input_obj = os.path.join(copy_info.input_dir.url, copy_info.old_filename) - output_obj = os.path.join(output_dir.url, copy_info.new_filename) - - copy_file_or_directory(input_obj, output_obj, storage_options=storage_options) + input_obj = parse.urlparse(os.path.join(copy_info.input_dir.url, copy_info.old_filename)) + output_obj = parse.urlparse(os.path.join(output_dir.url, copy_info.new_filename)) + + s3 = S3Client() + s3.client.copy( + {"Bucket": input_obj.netloc, "Key": input_obj.path.lstrip("/")}, + output_obj.netloc, + output_obj.path.lstrip("/"), + ) else: raise NotImplementedError -def _save_index(index_json: Dict, output_dir: Dir, storage_options: Optional[Dict] = {}) -> None: +def _save_index(index_json: Dict, output_dir: Dir) -> None: if output_dir.url is None: assert output_dir.path with open(os.path.join(output_dir.path, _INDEX_FILENAME), "w") as f: @@ -637,6 +620,11 @@ def _save_index(index_json: Dict, output_dir: Dir, storage_options: Optional[Dic f.flush() - upload_file_or_directory( - f.name, os.path.join(output_dir.url, _INDEX_FILENAME), storage_options=storage_options + obj = parse.urlparse(os.path.join(output_dir.url, _INDEX_FILENAME)) + + s3 = S3Client() + s3.client.upload_file( + f.name, + 
obj.netloc, + obj.path.lstrip("/"), ) diff --git a/src/litdata/processing/utilities.py b/src/litdata/processing/utilities.py index a13e863d..a50b9652 100644 --- a/src/litdata/processing/utilities.py +++ b/src/litdata/processing/utilities.py @@ -21,9 +21,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from urllib import parse -from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _SUPPORTED_CLOUD_PROVIDERS +import boto3 +import botocore + +from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO from litdata.streaming.cache import Dir -from litdata.streaming.downloader import download_file_or_directory def _create_dataset( @@ -181,7 +183,7 @@ def _get_work_dir() -> str: return f"s3://{bucket_name}/projects/{project_id}/lightningapps/{app_id}/artifacts/{work_id}/content/" -def read_index_file_content(output_dir: Dir, storage_options: Optional[Dict] = {}) -> Optional[Dict[str, Any]]: +def read_index_file_content(output_dir: Dir) -> Optional[Dict[str, Any]]: """Read the index file content.""" if not isinstance(output_dir, Dir): raise ValueError("The provided output_dir should be a Dir object.") @@ -199,26 +201,27 @@ def read_index_file_content(output_dir: Dir, storage_options: Optional[Dict] = { # download the index file from s3, and read it obj = parse.urlparse(output_dir.url) - if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: - raise ValueError( - f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {output_dir.path}." - ) + if obj.scheme != "s3": + raise ValueError(f"The provided folder should start with s3://. Found {output_dir.path}.") + + # TODO: Add support for all cloud providers + s3 = boto3.client("s3") + + prefix = obj.path.lstrip("/").rstrip("/") + "/" # Check the index file exists try: # Create a temporary file with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file: temp_file_name = temp_file.name - download_file_or_directory( - os.path.join(output_dir.url, _INDEX_FILENAME), temp_file_name, storage_options=storage_options - ) + s3.download_file(obj.netloc, os.path.join(prefix, _INDEX_FILENAME), temp_file_name) # Read data from the temporary file with open(temp_file_name) as temp_file: data = json.load(temp_file) # Delete the temporary file os.remove(temp_file_name) return data - except Exception as _e: + except botocore.exceptions.ClientError: return None @@ -253,3 +256,21 @@ def remove_uuid_from_filename(filepath: str) -> str: # uuid is of 32 characters, '.json' is 5 characters and '-' is 1 character return filepath[:-38] + ".json" + + +def download_directory_from_S3(bucket_name: str, remote_directory_name: str, local_directory_name: str) -> str: + s3_resource = boto3.resource("s3") + bucket = s3_resource.Bucket(bucket_name) + + saved_file_dir = "." + + for obj in bucket.objects.filter(Prefix=remote_directory_name): + local_filename = os.path.join(local_directory_name, obj.key) + + if not os.path.exists(os.path.dirname(local_filename)): + os.makedirs(os.path.dirname(local_filename)) + with open(local_filename, "wb") as f: + s3_resource.meta.client.download_fileobj(bucket_name, obj.key, f) + saved_file_dir = os.path.dirname(local_filename) + + return saved_file_dir diff --git a/src/litdata/streaming/client.py b/src/litdata/streaming/client.py new file mode 100644 index 00000000..d24803c3 --- /dev/null +++ b/src/litdata/streaming/client.py @@ -0,0 +1,70 @@ +# Copyright The Lightning AI team. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from time import time +from typing import Any, Dict, Optional + +import boto3 +import botocore +from botocore.credentials import InstanceMetadataProvider +from botocore.utils import InstanceMetadataFetcher + +from litdata.constants import _IS_IN_STUDIO + + +class S3Client: + # TODO: Generalize to support more cloud providers. + + def __init__(self, refetch_interval: int = 3300, storage_options: Optional[Dict] = {}) -> None: + self._refetch_interval = refetch_interval + self._last_time: Optional[float] = None + self._client: Optional[Any] = None + self._storage_options: dict = storage_options or {} + + def _create_client(self) -> None: + has_shared_credentials_file = ( + os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials" + ) + + if has_shared_credentials_file or not _IS_IN_STUDIO: + self._client = boto3.client( + "s3", + **{ + "config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}), + **self._storage_options, + }, + ) + else: + provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5)) + credentials = provider.load() + self._client = boto3.client( + "s3", + aws_access_key_id=credentials.access_key, + aws_secret_access_key=credentials.secret_key, + aws_session_token=credentials.token, + config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}), + ) + + @property + def client(self) -> Any: + if self._client is None: + self._create_client() + self._last_time = time() + + # Re-generate credentials for EC2 + if self._last_time is None or (time() - self._last_time) > self._refetch_interval: + self._create_client() + self._last_time = time() + + return self._client diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index a6f70bf5..ea82ce7a 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -155,8 +155,7 @@ def set_epoch(self, current_epoch: int) -> None: def _create_cache(self, worker_env: _WorkerEnv) -> Cache: if _should_replace_path(self.input_dir.path): cache_path = _try_create_cache_dir( - input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url, - storage_options=self.storage_options, + input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url ) if cache_path is not None: self.input_dir.path = cache_path @@ -439,8 +438,7 @@ def _validate_state_dict(self) -> None: # In this case, validate the cache folder is the same. 
if _should_replace_path(state["input_dir_path"]): cache_path = _try_create_cache_dir( - input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"], - storage_options=self.storage_options, + input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"] ) if cache_path != self.input_dir.path: raise ValueError( diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 463ab576..41e4a6a9 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -16,32 +16,24 @@ import shutil import subprocess from abc import ABC -from typing import Any, Dict, List, Optional, Union +from contextlib import suppress +from typing import Any, Dict, List, Optional from urllib import parse -import fsspec from filelock import FileLock, Timeout -from litdata.constants import _INDEX_FILENAME - -# from litdata.streaming.client import S3Client - -_USE_S5CMD_FOR_S3 = True +from litdata.constants import _AZURE_STORAGE_AVAILABLE, _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME +from litdata.streaming.client import S3Client class Downloader(ABC): def __init__( - self, - cloud_provider: str, - remote_dir: str, - cache_dir: str, - chunks: List[Dict[str, Any]], - storage_options: Optional[Dict] = {}, + self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} ): self._remote_dir = remote_dir self._cache_dir = cache_dir self._chunks = chunks - self.fs = fsspec.filesystem(cloud_provider, **storage_options) + self._storage_options = storage_options or {} def download_chunk_from_index(self, chunk_index: int) -> None: chunk_filename = self._chunks[chunk_index]["filename"] @@ -53,195 +45,157 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None: pass -class LocalDownloader(Downloader): - def download_file(self, remote_filepath: str, local_filepath: str) -> None: - if not os.path.exists(remote_filepath): - raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}") - - try: - with FileLock(local_filepath + ".lock", timeout=3 if remote_filepath.endswith(_INDEX_FILENAME) else 0): - if remote_filepath != local_filepath and not os.path.exists(local_filepath): - # make an atomic operation to be safe - temp_file_path = local_filepath + ".tmp" - shutil.copy(remote_filepath, temp_file_path) - os.rename(temp_file_path, local_filepath) - with contextlib.suppress(Exception): - os.remove(local_filepath + ".lock") - except Timeout: - pass +class S3Downloader(Downloader): + def __init__( + self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} + ): + super().__init__(remote_dir, cache_dir, chunks, storage_options) + self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 + if not self._s5cmd_available: + self._client = S3Client(storage_options=self._storage_options) -class LocalDownloaderWithCache(LocalDownloader): def download_file(self, remote_filepath: str, local_filepath: str) -> None: - remote_filepath = remote_filepath.replace("local:", "") - super().download_file(remote_filepath, local_filepath) - - -def download_s3_file_via_s5cmd(remote_filepath: str, local_filepath: str) -> None: - _s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 + obj = parse.urlparse(remote_filepath) - if _s5cmd_available is False: - raise ModuleNotFoundError(str(_s5cmd_available)) + if obj.scheme != "s3": + raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} 
for remote={remote_filepath}") - obj = parse.urlparse(remote_filepath) + if os.path.exists(local_filepath): + return - if obj.scheme != "s3": - raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for {remote_filepath}") + with suppress(Timeout), FileLock( + local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0 + ): + if self._s5cmd_available: + proc = subprocess.Popen( + f"s5cmd cp {remote_filepath} {local_filepath}", + shell=True, + stdout=subprocess.PIPE, + ) + proc.wait() + else: + from boto3.s3.transfer import TransferConfig + + extra_args: Dict[str, Any] = {} + + # try: + # with FileLock(local_filepath + ".lock", timeout=1): + if not os.path.exists(local_filepath): + # Issue: https://github.com/boto/boto3/issues/3113 + self._client.client.download_file( + obj.netloc, + obj.path.lstrip("/"), + local_filepath, + ExtraArgs=extra_args, + Config=TransferConfig(use_threads=False), + ) + + +class GCPDownloader(Downloader): + def __init__( + self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} + ): + if not _GOOGLE_STORAGE_AVAILABLE: + raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE)) - if os.path.exists(local_filepath): - return + super().__init__(remote_dir, cache_dir, chunks, storage_options) - try: - with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): - proc = subprocess.Popen( - f"s5cmd cp {remote_filepath} {local_filepath}", - shell=True, - stdout=subprocess.PIPE, - ) - proc.wait() - except Timeout: - # another process is responsible to download that file, continue - pass + def download_file(self, remote_filepath: str, local_filepath: str) -> None: + from google.cloud import storage + obj = parse.urlparse(remote_filepath) -_DOWNLOADERS = { - "s3://": "s3", - "gs://": "gs", - "azure://": "abfs", - "abfs://": "abfs", - "local:": "file", - "": "file", -} + if obj.scheme != "gs": + raise ValueError(f"Expected obj.scheme to be `gs`, instead, got {obj.scheme} for remote={remote_filepath}") -_DEFAULT_STORAGE_OPTIONS = { - "s3": {"config_kwargs": {"retries": {"max_attempts": 1000, "mode": "adaptive"}}}, -} + if os.path.exists(local_filepath): + return + with suppress(Timeout), FileLock( + local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0 + ): + bucket_name = obj.netloc + key = obj.path + # Remove the leading "/": + if key[0] == "/": + key = key[1:] -def get_complete_storage_options(cloud_provider: str, storage_options: Optional[Dict] = {}) -> Dict: - if storage_options is None: - storage_options = {} - if cloud_provider in _DEFAULT_STORAGE_OPTIONS: - return {**_DEFAULT_STORAGE_OPTIONS[cloud_provider], **storage_options} - return storage_options + client = storage.Client(**self._storage_options) + bucket = client.bucket(bucket_name) + blob = bucket.blob(key) + blob.download_to_filename(local_filepath) -class FsspecDownloader(Downloader): +class AzureDownloader(Downloader): def __init__( - self, - cloud_provider: str, - remote_dir: str, - cache_dir: str, - chunks: List[Dict[str, Any]], - storage_options: Optional[Dict] = {}, + self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} ): - remote_dir = remote_dir.replace("local:", "") - self.is_local = False - storage_options = get_complete_storage_options(cloud_provider, storage_options) - super().__init__(cloud_provider, remote_dir, cache_dir, chunks, storage_options) - self.cloud_provider = cloud_provider - 
self.use_s5cmd = cloud_provider == "s3" and os.system("s5cmd > /dev/null 2>&1") == 0 + if not _AZURE_STORAGE_AVAILABLE: + raise ModuleNotFoundError(str(_AZURE_STORAGE_AVAILABLE)) + + super().__init__(remote_dir, cache_dir, chunks, storage_options) def download_file(self, remote_filepath: str, local_filepath: str) -> None: - if os.path.exists(local_filepath) or remote_filepath == local_filepath: - return - if self.use_s5cmd and _USE_S5CMD_FOR_S3: - download_s3_file_via_s5cmd(remote_filepath, local_filepath) + from azure.storage.blob import BlobServiceClient + + obj = parse.urlparse(remote_filepath) + + if obj.scheme != "azure": + raise ValueError( + f"Expected obj.scheme to be `azure`, instead, got {obj.scheme} for remote={remote_filepath}" + ) + + if os.path.exists(local_filepath): return - try: - with FileLock(local_filepath + ".lock", timeout=3): - self.fs.get(remote_filepath, local_filepath, recursive=True) - # remove the lock file - if os.path.exists(local_filepath + ".lock"): - os.remove(local_filepath + ".lock") - except Timeout: - # another process is responsible to download that file, continue - pass - - -def does_file_exist( - remote_filepath: str, cloud_provider: Union[str, None] = None, storage_options: Optional[Dict] = {} -) -> bool: - if cloud_provider is None: - cloud_provider = get_cloud_provider(remote_filepath) - storage_options = get_complete_storage_options(cloud_provider, storage_options) - fs = fsspec.filesystem(cloud_provider, **storage_options) - return fs.exists(remote_filepath) - - -def list_directory( - remote_directory: str, - detail: bool = False, - cloud_provider: Optional[str] = None, - storage_options: Optional[Dict] = {}, -) -> List[str]: - """Returns a list of filenames in a remote directory.""" - if cloud_provider is None: - cloud_provider = get_cloud_provider(remote_directory) - storage_options = get_complete_storage_options(cloud_provider, storage_options) - fs = fsspec.filesystem(cloud_provider, **storage_options) - return fs.ls(remote_directory, detail=detail) # just return the filenames - - -def download_file_or_directory(remote_filepath: str, local_filepath: str, storage_options: Optional[Dict] = {}) -> None: - """Download a file from the remote cloud storage.""" - fs_cloud_provider = get_cloud_provider(remote_filepath) - use_s5cmd = fs_cloud_provider == "s3" and os.system("s5cmd > /dev/null 2>&1") == 0 - if use_s5cmd and _USE_S5CMD_FOR_S3: - download_s3_file_via_s5cmd(remote_filepath, local_filepath) - return - try: - with FileLock(local_filepath + ".lock", timeout=3): - storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) - fs = fsspec.filesystem(fs_cloud_provider, **storage_options) - fs.get(remote_filepath, local_filepath, recursive=True) - except Timeout: - # another process is responsible to download that file, continue - pass + with suppress(Timeout), FileLock( + local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0 + ): + service = BlobServiceClient(**self._storage_options) + blob_client = service.get_blob_client(container=obj.netloc, blob=obj.path.lstrip("/")) + with open(local_filepath, "wb") as download_file: + blob_data = blob_client.download_blob() + blob_data.readinto(download_file) -def upload_file_or_directory(local_filepath: str, remote_filepath: str, storage_options: Optional[Dict] = {}) -> None: - """Upload a file to the remote cloud storage.""" - try: - with FileLock(local_filepath + ".lock", timeout=3): - fs_cloud_provider = get_cloud_provider(remote_filepath) - 
storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) - fs = fsspec.filesystem(fs_cloud_provider, **storage_options) - fs.put(local_filepath, remote_filepath, recursive=True) - except Timeout: - # another process is responsible to upload that file, continue - pass +class LocalDownloader(Downloader): + def download_file(self, remote_filepath: str, local_filepath: str) -> None: + if not os.path.exists(remote_filepath): + raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}") -def copy_file_or_directory( - remote_filepath_src: str, remote_filepath_tg: str, storage_options: Optional[Dict] = {} -) -> None: - """Copy a file from src to target on the remote cloud storage.""" - fs_cloud_provider = get_cloud_provider(remote_filepath_src) - storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) - fs = fsspec.filesystem(fs_cloud_provider, **storage_options) - fs.copy(remote_filepath_src, remote_filepath_tg, recursive=True) + with suppress(Timeout), FileLock( + local_filepath + ".lock", timeout=3 if remote_filepath.endswith(_INDEX_FILENAME) else 0 + ): + if remote_filepath == local_filepath or os.path.exists(local_filepath): + return + # make an atomic operation to be safe + temp_file_path = local_filepath + ".tmp" + shutil.copy(remote_filepath, temp_file_path) + os.rename(temp_file_path, local_filepath) + with contextlib.suppress(Exception): + os.remove(local_filepath + ".lock") -def remove_file_or_directory(remote_filepath: str, storage_options: Optional[Dict] = {}) -> None: - """Remove a file from the remote cloud storage.""" - fs_cloud_provider = get_cloud_provider(remote_filepath) - storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) - fs = fsspec.filesystem(fs_cloud_provider, **storage_options) - fs.rm(remote_filepath, recursive=True) +class LocalDownloaderWithCache(LocalDownloader): + def download_file(self, remote_filepath: str, local_filepath: str) -> None: + remote_filepath = remote_filepath.replace("local:", "") + super().download_file(remote_filepath, local_filepath) -def get_cloud_provider(remote_filepath: str) -> str: - for k, fs_cloud_provider in _DOWNLOADERS.items(): - if str(remote_filepath).startswith(k): - return fs_cloud_provider - raise ValueError(f"The provided `remote_filepath` {remote_filepath} doesn't have a downloader associated.") +_DOWNLOADERS = { + "s3://": S3Downloader, + "gs://": GCPDownloader, + "azure://": AzureDownloader, + "local:": LocalDownloaderWithCache, + "": LocalDownloader, +} def get_downloader_cls( remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} ) -> Downloader: - for k, fs_cloud_provider in _DOWNLOADERS.items(): + for k, cls in _DOWNLOADERS.items(): if str(remote_dir).startswith(k): - return FsspecDownloader(fs_cloud_provider, remote_dir, cache_dir, chunks, storage_options) + return cls(remote_dir, cache_dir, chunks, storage_options) raise ValueError(f"The provided `remote_dir` {remote_dir} doesn't have a downloader associated.") diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index a1781ccc..98ce5fef 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -20,15 +20,13 @@ from dataclasses import dataclass from pathlib import Path from time import sleep -from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union from urllib import parse -from 
litdata.constants import _LIGHTNING_SDK_AVAILABLE, _SUPPORTED_CLOUD_PROVIDERS -from litdata.streaming.downloader import ( - does_file_exist, - list_directory, - remove_file_or_directory, -) +import boto3 +import botocore + +from litdata.constants import _LIGHTNING_SDK_AVAILABLE if TYPE_CHECKING: from lightning_sdk import Machine @@ -54,15 +52,9 @@ def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir: assert isinstance(dir_path, str) - cloud_prefixes = _SUPPORTED_CLOUD_PROVIDERS - dir_scheme = parse.urlparse(dir_path).scheme - if bool(dir_scheme) and dir_scheme not in ["c", "d", "e", "f"]: # prevent windows `c:\\` and `d:\\` - if any(dir_path.startswith(cloud_prefix) for cloud_prefix in cloud_prefixes): - return Dir(path=None, url=dir_path) - raise ValueError( - f"The provided dir_path `{dir_path}` is not supported.", - f" HINT: Only the following cloud providers are supported: {_SUPPORTED_CLOUD_PROVIDERS}.", - ) + cloud_prefixes = ("s3://", "gs://", "azure://") + if dir_path.startswith(cloud_prefixes): + return Dir(path=None, url=dir_path) if dir_path.startswith("local:"): return Dir(path=None, url=dir_path) @@ -96,11 +88,14 @@ def _match_studio(target_id: Optional[str], target_name: Optional[str], cloudspa if target_id is not None and cloudspace.id == target_id: return True - return bool( + if ( cloudspace.display_name is not None and target_name is not None and cloudspace.display_name.lower() == target_name.lower() - ) + ): + return True + + return False def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Optional[str]) -> Dir: @@ -214,9 +209,7 @@ def _resolve_datasets(dir_path: str) -> Dir: ) -def _assert_dir_is_empty( - output_dir: Dir, append: bool = False, overwrite: bool = False, storage_options: Optional[Dict] = {} -) -> None: +def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool = False) -> None: if not isinstance(output_dir, Dir): raise ValueError("The provided output_dir isn't a `Dir` Object.") @@ -225,16 +218,20 @@ def _assert_dir_is_empty( obj = parse.urlparse(output_dir.url) - if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: - raise ValueError(f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {output_dir.url}.") + if obj.scheme != "s3": + raise ValueError(f"The provided folder should start with s3://. Found {output_dir.url}.") - try: - object_list = list_directory(output_dir.url, storage_options=storage_options) - except FileNotFoundError: - return + s3 = boto3.client("s3") + + objects = s3.list_objects_v2( + Bucket=obj.netloc, + Delimiter="/", + Prefix=obj.path.lstrip("/").rstrip("/") + "/", + ) # We aren't alloweing to add more data - if object_list is not None and len(object_list) > 0: + # TODO: Add support for `append` and `overwrite`. + if objects["KeyCount"] > 0: raise RuntimeError( f"The provided output_dir `{output_dir.path}` already contains data and datasets are meant to be immutable." "\n HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?" @@ -242,10 +239,7 @@ def _assert_dir_is_empty( def _assert_dir_has_index_file( - output_dir: Dir, - mode: Optional[Literal["append", "overwrite"]] = None, - use_checkpoint: bool = False, - storage_options: Optional[Dict] = {}, + output_dir: Dir, mode: Optional[Literal["append", "overwrite"]] = None, use_checkpoint: bool = False ) -> None: if mode is not None and mode not in ["append", "overwrite"]: raise ValueError(f"The provided `mode` should be either `append` or `overwrite`. 
Found {mode}.") @@ -289,19 +283,29 @@ def _assert_dir_has_index_file( obj = parse.urlparse(output_dir.url) - if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: - raise ValueError(f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {output_dir.url}.") + if obj.scheme != "s3": + raise ValueError(f"The provided folder should start with s3://. Found {output_dir.url}.") + + s3 = boto3.client("s3") - objects_list = [] - with suppress(FileNotFoundError): - objects_list = list_directory(output_dir.url, storage_options=storage_options) + prefix = obj.path.lstrip("/").rstrip("/") + "/" + + objects = s3.list_objects_v2( + Bucket=obj.netloc, + Delimiter="/", + Prefix=prefix, + ) # No files are found in this folder - if objects_list is None or len(objects_list) == 0: + if objects["KeyCount"] == 0: return # Check the index file exists - has_index_file = does_file_exist(os.path.join(output_dir.url, "index.json"), storage_options=storage_options) + try: + s3.head_object(Bucket=obj.netloc, Key=os.path.join(prefix, "index.json")) + has_index_file = True + except botocore.exceptions.ClientError: + has_index_file = False if has_index_file and mode is None: raise RuntimeError( @@ -310,8 +314,13 @@ def _assert_dir_has_index_file( "\n HINT: If you want to append/overwrite to the existing dataset, use `mode='append | overwrite'`." ) + # Delete all the files (including the index file in overwrite mode) + bucket_name = obj.netloc + s3 = boto3.resource("s3") + if mode == "overwrite" or (mode is None and not use_checkpoint): - remove_file_or_directory(output_dir.url, storage_options=storage_options) + for obj in s3.Bucket(bucket_name).objects.filter(Prefix=prefix): + s3.Object(bucket_name, obj.key).delete() def _get_lightning_cloud_url() -> str: diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py index ef111541..cc3ef9b2 100644 --- a/tests/processing/test_data_processor.py +++ b/tests/processing/test_data_processor.py @@ -109,16 +109,20 @@ def fn(*_, **__): remove_queue = mock.MagicMock() + s3_client = mock.MagicMock() + called = False - def copy_file(local_filepath, *args, **kwargs): + def copy_file(local_filepath, *args): nonlocal called called = True from shutil import copyfile copyfile(local_filepath, os.path.join(remote_output_dir, os.path.basename(local_filepath))) - monkeypatch.setattr(data_processor_module, "upload_file_or_directory", copy_file) + s3_client.client.upload_file = copy_file + + monkeypatch.setattr(data_processor_module, "S3Client", mock.MagicMock(return_value=s3_client)) assert os.listdir(remote_output_dir) == [] @@ -213,28 +217,32 @@ def test_wait_for_disk_usage_higher_than_threshold(): @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows") -def test_wait_for_file_to_exist(monkeypatch): +def test_wait_for_file_to_exist(): + import botocore + + s3 = mock.MagicMock() + obj = mock.MagicMock() raise_error = [True, True, False] def fn(*_, **__): value = raise_error.pop(0) if value: - raise Exception({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") # some exception + raise botocore.exceptions.ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") return - monkeypatch.setattr(data_processor_module, "does_file_exist", fn) + s3.client.head_object = fn - _wait_for_file_to_exist("s3://some-dummy-bucket/some-dummy-key", sleep_time=0.01) + _wait_for_file_to_exist(s3, obj, sleep_time=0.01) assert len(raise_error) == 0 def fn(*_, **__): raise ValueError("HERE") - 
monkeypatch.setattr(data_processor_module, "does_file_exist", fn) + s3.client.head_object = fn with pytest.raises(ValueError, match="HERE"): - _wait_for_file_to_exist("s3://some-dummy-bucket/some-dummy-key", sleep_time=0.01) + _wait_for_file_to_exist(s3, obj, sleep_time=0.01) def test_cache_dir_cleanup(tmpdir, monkeypatch): @@ -1016,10 +1024,11 @@ def test_data_processing_map_non_absolute_path(monkeypatch, tmpdir): @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows") def test_map_error_when_not_empty(monkeypatch): - def mock_list_directory(*args, **kwargs): - return ["a.txt", "b.txt"] - - monkeypatch.setattr(resolver, "list_directory", mock_list_directory) + boto3 = mock.MagicMock() + client_s3_mock = mock.MagicMock() + client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []} + boto3.client.return_value = client_s3_mock + monkeypatch.setattr(resolver, "boto3", boto3) with pytest.raises(RuntimeError, match="data and datasets are meant to be immutable"): map( diff --git a/tests/streaming/test_client.py b/tests/streaming/test_client.py new file mode 100644 index 00000000..78ea919d --- /dev/null +++ b/tests/streaming/test_client.py @@ -0,0 +1,97 @@ +import sys +from time import sleep, time +from unittest import mock + +import pytest +from litdata.streaming import client + + +def test_s3_client_with_storage_options(monkeypatch): + boto3 = mock.MagicMock() + monkeypatch.setattr(client, "boto3", boto3) + + botocore = mock.MagicMock() + monkeypatch.setattr(client, "botocore", botocore) + + storage_options = { + "region_name": "us-west-2", + "endpoint_url": "https://custom.endpoint", + "config": botocore.config.Config(retries={"max_attempts": 100}), + } + s3_client = client.S3Client(storage_options=storage_options) + + assert s3_client.client + + boto3.client.assert_called_with( + "s3", + region_name="us-west-2", + endpoint_url="https://custom.endpoint", + config=botocore.config.Config(retries={"max_attempts": 100}), + ) + + s3_client = client.S3Client() + + assert s3_client.client + + boto3.client.assert_called_with( + "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}) + ) + + +def test_s3_client_without_cloud_space_id(monkeypatch): + boto3 = mock.MagicMock() + monkeypatch.setattr(client, "boto3", boto3) + + botocore = mock.MagicMock() + monkeypatch.setattr(client, "botocore", botocore) + + instance_metadata_provider = mock.MagicMock() + monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider) + + instance_metadata_fetcher = mock.MagicMock() + monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher) + + s3 = client.S3Client(1) + assert s3.client + assert s3.client + assert s3.client + assert s3.client + assert s3.client + + boto3.client.assert_called_once() + + +@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows") +@pytest.mark.parametrize("use_shared_credentials", [False, True, None]) +def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch): + boto3 = mock.MagicMock() + monkeypatch.setattr(client, "boto3", boto3) + + botocore = mock.MagicMock() + monkeypatch.setattr(client, "botocore", botocore) + + if isinstance(use_shared_credentials, bool): + monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy") + monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", "/.credentials/.aws_credentials") + monkeypatch.setenv("AWS_CONFIG_FILE", "/.credentials/.aws_credentials") + + instance_metadata_provider = mock.MagicMock() 
+ monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider) + + instance_metadata_fetcher = mock.MagicMock() + monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher) + + s3 = client.S3Client(1) + assert s3.client + assert s3.client + boto3.client.assert_called_once() + sleep(1 - (time() - s3._last_time)) + assert s3.client + assert s3.client + assert len(boto3.client._mock_mock_calls) == 6 + sleep(1 - (time() - s3._last_time)) + assert s3.client + assert s3.client + assert len(boto3.client._mock_mock_calls) == 9 + + assert instance_metadata_provider._mock_call_count == 0 if use_shared_credentials else 3 diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index b87f9ffd..ef93021c 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -889,7 +889,7 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False): @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) -@pytest.mark.timeout(120) +@pytest.mark.timeout(60) @pytest.mark.parametrize("shuffle", [True, False]) def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): """Tests resuming from a chunk past the first chunk, when subsequent chunks don't have the same size.""" diff --git a/tests/streaming/test_downloader.py b/tests/streaming/test_downloader.py index 97368c0c..7c79afe5 100644 --- a/tests/streaming/test_downloader.py +++ b/tests/streaming/test_downloader.py @@ -1,19 +1,84 @@ import os +from unittest import mock from unittest.mock import MagicMock from litdata.streaming.downloader import ( + AzureDownloader, + GCPDownloader, LocalDownloaderWithCache, + S3Downloader, shutil, + subprocess, ) +def test_s3_downloader_fast(tmpdir, monkeypatch): + monkeypatch.setattr(os, "system", MagicMock(return_value=0)) + popen_mock = MagicMock() + monkeypatch.setattr(subprocess, "Popen", MagicMock(return_value=popen_mock)) + downloader = S3Downloader(tmpdir, tmpdir, []) + downloader.download_file("s3://random_bucket/a.txt", os.path.join(tmpdir, "a.txt")) + popen_mock.wait.assert_called() + + +@mock.patch("litdata.streaming.downloader._GOOGLE_STORAGE_AVAILABLE", True) +def test_gcp_downloader(tmpdir, monkeypatch, google_mock): + # Create mock objects + mock_client = MagicMock() + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.download_to_filename = MagicMock() + + # Patch the storage client to return the mock client + google_mock.cloud.storage.Client = MagicMock(return_value=mock_client) + + # Configure the mock client to return the mock bucket and blob + mock_client.bucket = MagicMock(return_value=mock_bucket) + mock_bucket.blob = MagicMock(return_value=mock_blob) + + # Initialize the downloader + storage_options = {"project": "DUMMY_PROJECT"} + downloader = GCPDownloader("gs://random_bucket", tmpdir, [], storage_options) + local_filepath = os.path.join(tmpdir, "a.txt") + downloader.download_file("gs://random_bucket/a.txt", local_filepath) + + # Assert that the correct methods were called + google_mock.cloud.storage.Client.assert_called_with(**storage_options) + mock_client.bucket.assert_called_with("random_bucket") + mock_bucket.blob.assert_called_with("a.txt") + mock_blob.download_to_filename.assert_called_with(local_filepath) + + +@mock.patch("litdata.streaming.downloader._AZURE_STORAGE_AVAILABLE", True) +def test_azure_downloader(tmpdir, monkeypatch, azure_mock): + mock_blob = MagicMock() + mock_blob_data = 
MagicMock() + mock_blob.download_blob.return_value = mock_blob_data + service_mock = MagicMock() + service_mock.get_blob_client.return_value = mock_blob + + azure_mock.storage.blob.BlobServiceClient = MagicMock(return_value=service_mock) + + # Initialize the downloader + storage_options = {"project": "DUMMY_PROJECT"} + downloader = AzureDownloader("azure://random_bucket", tmpdir, [], storage_options) + local_filepath = os.path.join(tmpdir, "a.txt") + downloader.download_file("azure://random_bucket/a.txt", local_filepath) + + # Assert that the correct methods were called + azure_mock.storage.blob.BlobServiceClient.assert_called_with(**storage_options) + service_mock.get_blob_client.assert_called_with(container="random_bucket", blob="a.txt") + mock_blob.download_blob.assert_called() + mock_blob_data.readinto.assert_called() + + def test_download_with_cache(tmpdir, monkeypatch): # Create a file to download/cache with open("a.txt", "w") as f: f.write("hello") try: - local_downloader = LocalDownloaderWithCache("file", tmpdir, tmpdir, []) + local_downloader = LocalDownloaderWithCache(tmpdir, tmpdir, []) shutil_mock = MagicMock() os_mock = MagicMock() monkeypatch.setattr(shutil, "copy", shutil_mock) diff --git a/tests/streaming/test_resolver.py b/tests/streaming/test_resolver.py index 699a39c4..90729ffb 100644 --- a/tests/streaming/test_resolver.py +++ b/tests/streaming/test_resolver.py @@ -302,54 +302,52 @@ def print_fn(msg, file=None): def test_assert_dir_is_empty(monkeypatch): - def mock_list_directory(*args, **kwargs): - return ["a.txt", "b.txt"] - - def mock_empty_list_directory(*args, **kwargs): - return [] - - monkeypatch.setattr(resolver, "list_directory", mock_list_directory) + boto3 = mock.MagicMock() + client_s3_mock = mock.MagicMock() + client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []} + boto3.client.return_value = client_s3_mock + resolver.boto3 = boto3 with pytest.raises(RuntimeError, match="The provided output_dir"): resolver._assert_dir_is_empty(resolver.Dir(path="/teamspace/...", url="s3://")) - monkeypatch.setattr(resolver, "list_directory", mock_empty_list_directory) + client_s3_mock.list_objects_v2.return_value = {"KeyCount": 0, "Contents": []} + boto3.client.return_value = client_s3_mock + resolver.boto3 = boto3 resolver._assert_dir_is_empty(resolver.Dir(path="/teamspace/...", url="s3://")) def test_assert_dir_has_index_file(monkeypatch): - def mock_list_directory_0(*args, **kwargs): - return [] - - def mock_list_directory_1(*args, **kwargs): - return ["a.txt", "b.txt"] + boto3 = mock.MagicMock() + client_s3_mock = mock.MagicMock() + client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []} + boto3.client.return_value = client_s3_mock + resolver.boto3 = boto3 - def mock_list_directory_2(*args, **kwargs): - return ["index.json"] + with pytest.raises(RuntimeError, match="The provided output_dir"): + resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) - def mock_does_file_exist_1(*args, **kwargs): - raise Exception({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") # some exception + client_s3_mock.list_objects_v2.return_value = {"KeyCount": 0, "Contents": []} + boto3.client.return_value = client_s3_mock + resolver.boto3 = boto3 - def mock_does_file_exist_2(*args, **kwargs): - return True + resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) - def mock_remove_file_or_directory(*args, **kwargs): - return + client_s3_mock.list_objects_v2.return_value = 
{"KeyCount": 1, "Contents": []} - monkeypatch.setattr(resolver, "list_directory", mock_list_directory_0) - monkeypatch.setattr(resolver, "does_file_exist", mock_does_file_exist_1) - monkeypatch.setattr(resolver, "remove_file_or_directory", mock_remove_file_or_directory) + def head_object(*args, **kwargs): + import botocore - resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) + raise botocore.exceptions.ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") - monkeypatch.setattr(resolver, "list_directory", mock_list_directory_2) - monkeypatch.setattr(resolver, "does_file_exist", mock_does_file_exist_2) + client_s3_mock.head_object = head_object + boto3.client.return_value = client_s3_mock + resolver.boto3 = boto3 - with pytest.raises(RuntimeError, match="The provided output_dir"): - resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) + resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) - resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://"), mode="overwrite") + boto3.resource.assert_called() def test_resolve_dir_absolute(tmp_path, monkeypatch): @@ -369,10 +367,3 @@ def test_resolve_dir_absolute(tmp_path, monkeypatch): link.symlink_to(src) assert link.resolve() == src assert resolver._resolve_dir(str(link)).path == str(src) - - -def test_resolve_dir_unsupported_cloud_provider(monkeypatch, tmp_path): - """Test that the unsupported cloud provider is handled correctly.""" - test_dir = "some-random-cloud-provider://some-random-bucket" - with pytest.raises(ValueError, match="The provided dir_path"): - resolver._resolve_dir(test_dir)