From 96ec15de0cbc2f47dcb36f369b580a99d2ce362e Mon Sep 17 00:00:00 2001
From: deependu
Date: Sun, 1 Sep 2024 18:46:18 +0530
Subject: [PATCH 01/43] fsspec basic setup done and working for s3

---
 requirements.txt                    |  4 +++
 src/litdata/streaming/downloader.py | 48 +++++++++++++++++++++++++++--
 2 files changed, 49 insertions(+), 3 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 06a629a0..0fdc59c5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,7 @@ filelock
 numpy
 boto3
 requests
+fsspec
+fsspec[s3] # aws s3
+fsspec[gs] # google cloud storage
+fsspec[abfs] # azure blob
diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py
index ffdbe193..9a69a024 100644
--- a/src/litdata/streaming/downloader.py
+++ b/src/litdata/streaming/downloader.py
@@ -18,6 +18,7 @@
 from typing import Any, Dict, List, Optional
 from urllib import parse
 
+import fsspec
 from filelock import FileLock, Timeout
 
 from litdata.constants import _AZURE_STORAGE_AVAILABLE, _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME
@@ -26,8 +27,17 @@ class Downloader(ABC):
     def __init__(
-        self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
+        self,
+        cloud_provider: str,
+        remote_dir: str,
+        cache_dir: str,
+        chunks: List[Dict[str, Any]],
+        storage_options: Optional[Dict] = {},
     ):
+        print("-"*100)
+        print(f"{cloud_provider=}")
+        print("-" * 100)
+        self.fs = fsspec.filesystem(cloud_provider)
         self._remote_dir = remote_dir
         self._cache_dir = cache_dir
         self._chunks = chunks
@@ -188,10 +198,42 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
 }
 
 
+_DOWNLOADERS = {
+    "s3://": 's3',
+    "gs://": 'gs',
+    "azure://": 'abfs',
+    "local:": 'file',
+    "": 'file',
+}
+
+
+class FsspecDownloader(Downloader):
+    def __init__(
+        self,
+        cloud_provider: str,
+        remote_dir: str,
+        cache_dir: str,
+        chunks: List[Dict[str, Any]],
+        storage_options: Dict | None = {},
+    ):
+        remote_dir = remote_dir.replace("local:", "")
+        super().__init__(cloud_provider, remote_dir, cache_dir, chunks, storage_options)
+
+    def download_file(self, remote_filepath: str, local_filepath: str) -> None:
+        if os.path.exists(local_filepath) or remote_filepath == local_filepath:
+            return
+        try:
+            with FileLock(local_filepath + ".lock", timeout=3):
+                self.fs.get(remote_filepath, local_filepath, recursive=True, **self._storage_options)
+        except Timeout:
+            # another process is responsible to download that file, continue
+            pass
+
+
 def get_downloader_cls(
     remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
 ) -> Downloader:
-    for k, cls in _DOWNLOADERS.items():
+    for k, fs_cloud_provider in _DOWNLOADERS.items():
         if str(remote_dir).startswith(k):
-            return cls(remote_dir, cache_dir, chunks, storage_options)
+            return FsspecDownloader(fs_cloud_provider, remote_dir, cache_dir, chunks, storage_options)
     raise ValueError(f"The provided `remote_dir` {remote_dir} doesn't have a downloader associated.")

From 45b59aea51a8efd6ba31331d8e1156ad696e9692 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 1 Sep 2024 13:20:52 +0000
Subject: [PATCH 02/43] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/litdata/streaming/downloader.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py
index 9a69a024..2c46a5fb 100644
--- a/src/litdata/streaming/downloader.py
+++ b/src/litdata/streaming/downloader.py
@@ -34,7 +34,7 @@ def __init__(
         chunks: List[Dict[str, Any]],
         storage_options: Optional[Dict] = {},
     ):
-        print("-"*100)
+        print("-" * 100)
         print(f"{cloud_provider=}")
         print("-" * 100)
         self.fs = fsspec.filesystem(cloud_provider)
@@ -199,11 +199,11 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
 
 
 _DOWNLOADERS = {
-    "s3://": 's3',
-    "gs://": 'gs',
-    "azure://": 'abfs',
-    "local:": 'file',
-    "": 'file',
+    "s3://": "s3",
+    "gs://": "gs",
+    "azure://": "abfs",
+    "local:": "file",
+    "": "file",
 }

From 74dae218244bb7716e44d87c8ff4159528b4f408 Mon Sep 17 00:00:00 2001
From: deependu
Date: Mon, 2 Sep 2024 19:13:06 +0530
Subject: [PATCH 03/43] fix storage option in fsspec

---
 src/litdata/streaming/downloader.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py
index 2c46a5fb..8c4d8bc9 100644
--- a/src/litdata/streaming/downloader.py
+++ b/src/litdata/streaming/downloader.py
@@ -34,14 +34,10 @@ def __init__(
         chunks: List[Dict[str, Any]],
         storage_options: Optional[Dict] = {},
     ):
-        print("-" * 100)
-        print(f"{cloud_provider=}")
-        print("-" * 100)
-        self.fs = fsspec.filesystem(cloud_provider)
         self._remote_dir = remote_dir
         self._cache_dir = cache_dir
         self._chunks = chunks
-        self._storage_options = storage_options or {}
+        self.fs = fsspec.filesystem(cloud_provider, **storage_options)
 
     def download_chunk_from_index(self, chunk_index: int) -> None:
         chunk_filename = self._chunks[chunk_index]["filename"]
@@ -224,7 +220,7 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
             return
         try:
             with FileLock(local_filepath + ".lock", timeout=3):
-                self.fs.get(remote_filepath, local_filepath, recursive=True, **self._storage_options)
+                self.fs.get(remote_filepath, local_filepath, recursive=True)
         except Timeout:
             # another process is responsible to download that file, continue
             pass

From fcb4d955df2eb2d75d79107b86a14e7334e2464d Mon Sep 17 00:00:00 2001
From: deependu
Date: Mon, 2 Sep 2024 21:20:15 +0530
Subject: [PATCH 04/43] pass down `storage_options` in dataset utilities

---
 src/litdata/streaming/dataset.py           |  6 ++++--
 src/litdata/utilities/dataset_utilities.py | 13 ++++++++-----
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index 5c57cd69..fb655fd5 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -155,7 +155,8 @@ def set_epoch(self, current_epoch: int) -> None:
     def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
         if _should_replace_path(self.input_dir.path):
             cache_path = _try_create_cache_dir(
-                input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url
+                input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url,
+                storage_options=self.storage_options,
             )
             if cache_path is not None:
                 self.input_dir.path = cache_path
@@ -438,7 +439,8 @@ def _validate_state_dict(self) -> None:
             # In this case, validate the cache folder is the same.
if _should_replace_path(state["input_dir_path"]): cache_path = _try_create_cache_dir( - input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"] + input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"], + storage_options=self.storage_options, ) if cache_path != self.input_dir.path: raise ValueError( diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 08194e51..88b7630a 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -38,7 +38,10 @@ def subsample_streaming_dataset( # Make sure input_dir contains cache path and remote url if _should_replace_path(input_dir.path): - cache_path = _try_create_cache_dir(input_dir=input_dir.path if input_dir.path else input_dir.url) + cache_path = _try_create_cache_dir( + input_dir=input_dir.path if input_dir.path else input_dir.url, + storage_options=storage_options + ) if cache_path is not None: input_dir.path = cache_path @@ -93,7 +96,7 @@ def _should_replace_path(path: Optional[str]) -> bool: return path.startswith("/teamspace/datasets/") or path.startswith("/teamspace/s3_connections/") -def _read_updated_at(input_dir: Optional[Dir]) -> str: +def _read_updated_at(input_dir: Optional[Dir], storage_options: Optional[Dict]) -> str: """Read last updated timestamp from index.json file.""" last_updation_timestamp = "0" index_json_content = None @@ -107,7 +110,7 @@ def _read_updated_at(input_dir: Optional[Dir]) -> str: # download index.json file and read last_updation_timestamp with tempfile.TemporaryDirectory() as tmp_directory: temp_index_filepath = os.path.join(tmp_directory, _INDEX_FILENAME) - downloader = get_downloader_cls(input_dir.url, tmp_directory, []) + downloader = get_downloader_cls(input_dir.url, tmp_directory, [], storage_options) downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), temp_index_filepath) index_json_content = load_index_file(tmp_directory) @@ -132,9 +135,9 @@ def _clear_cache_dir_if_updated(input_dir_hash_filepath: str, updated_at_hash: s shutil.rmtree(input_dir_hash_filepath) -def _try_create_cache_dir(input_dir: Optional[str]) -> Optional[str]: +def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Dict]) -> Optional[str]: resolved_input_dir = _resolve_dir(input_dir) - updated_at = _read_updated_at(resolved_input_dir) + updated_at = _read_updated_at(resolved_input_dir, storage_options) if updated_at == "0" and input_dir is not None: updated_at = hashlib.md5(input_dir.encode()).hexdigest() # noqa: S324 From 3080c2cdf2652d63a0e857ae6e400624a453bace Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 15:50:56 +0000 Subject: [PATCH 05/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/utilities/dataset_utilities.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 88b7630a..a7ff6840 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -39,8 +39,7 @@ def subsample_streaming_dataset( # Make sure input_dir contains cache path and remote url if _should_replace_path(input_dir.path): cache_path = _try_create_cache_dir( - input_dir=input_dir.path if input_dir.path else input_dir.url, - storage_options=storage_options + 
input_dir=input_dir.path if input_dir.path else input_dir.url, storage_options=storage_options ) if cache_path is not None: input_dir.path = cache_path From c31e259d3a6d8ac5c1f676bcf5c76cd360fc7777 Mon Sep 17 00:00:00 2001 From: deependu Date: Tue, 3 Sep 2024 22:01:11 +0530 Subject: [PATCH 06/43] tested successfully on S3 and GS for (mode= none | append | overwrite), checkpoint, merge_datasets. --- src/litdata/constants.py | 1 + src/litdata/processing/data_processor.py | 131 ++++++++++------------- src/litdata/processing/functions.py | 28 ++--- src/litdata/processing/utilities.py | 39 ++----- src/litdata/streaming/downloader.py | 83 ++++++++++++-- src/litdata/streaming/resolver.py | 63 ++++------- 6 files changed, 172 insertions(+), 173 deletions(-) diff --git a/src/litdata/constants.py b/src/litdata/constants.py index dab8f7ce..8f8f0f7a 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -86,3 +86,4 @@ _TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ" _IS_IN_STUDIO = bool(os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)) and bool(os.getenv("LIGHTNING_CLUSTER_ID", None)) _ENABLE_STATUS = bool(int(os.getenv("ENABLE_STATUS_REPORT", "0"))) +_SUPPORTED_CLOUD_PROVIDERS = ["s3", "gs", "azure"] diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 946e68c6..6b3d1a5c 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -32,8 +32,6 @@ from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union from urllib import parse -import boto3 -import botocore import numpy as np import torch @@ -42,14 +40,21 @@ _ENABLE_STATUS, _INDEX_FILENAME, _IS_IN_STUDIO, + _SUPPORTED_CLOUD_PROVIDERS, _TQDM_AVAILABLE, ) from litdata.processing.readers import BaseReader, StreamingDataLoaderReader -from litdata.processing.utilities import _create_dataset, download_directory_from_S3, remove_uuid_from_filename +from litdata.processing.utilities import _create_dataset, remove_uuid_from_filename from litdata.streaming import Cache from litdata.streaming.cache import Dir -from litdata.streaming.client import S3Client from litdata.streaming.dataloader import StreamingDataLoader +from litdata.streaming.downloader import ( + does_file_exist, + download_file_or_directory, + get_cloud_provider, + remove_file_or_directory, + upload_file_or_directory, +) from litdata.streaming.item_loader import BaseItemLoader from litdata.streaming.resolver import _resolve_dir from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads @@ -96,14 +101,16 @@ def _get_cache_data_dir(name: Optional[str] = None) -> str: return os.path.join(cache_dir, name.lstrip("/")) -def _wait_for_file_to_exist(s3: S3Client, obj: parse.ParseResult, sleep_time: int = 2) -> Any: - """This function check.""" +def _wait_for_file_to_exist(remote_filepath: str, sleep_time: int = 2, wait_for_count: int = 5) -> Any: + """This function check if a file exists on the remote storage. 
If not, it waits for a while and tries again.""" + cloud_provider = get_cloud_provider(remote_filepath) while True: try: - return s3.client.head_object(Bucket=obj.netloc, Key=obj.path.lstrip("/")) - except botocore.exceptions.ClientError as e: - if "the HeadObject operation: Not Found" in str(e): + return does_file_exist(remote_filepath, cloud_provider) + except Exception as e: + if wait_for_count > 0: sleep(sleep_time) + wait_for_count -= 1 else: raise e @@ -120,7 +127,6 @@ def _wait_for_disk_usage_higher_than_threshold(input_dir: str, threshold_in_gb: def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue) -> None: """This function is used to download data from a remote directory to a cache directory to optimise reading.""" - s3 = S3Client() while True: # 2. Fetch from the queue @@ -156,13 +162,11 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue obj = parse.urlparse(path) - if obj.scheme == "s3": + if obj.scheme in _SUPPORTED_CLOUD_PROVIDERS: dirpath = os.path.dirname(local_path) os.makedirs(dirpath, exist_ok=True) - - with open(local_path, "wb") as f: - s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f) + download_file_or_directory(path, local_path) elif os.path.isfile(path): if not path.startswith("/teamspace/studios/this_studio"): @@ -202,8 +206,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_ """This function is used to upload optimised chunks from a local to remote dataset directory.""" obj = parse.urlparse(output_dir.url if output_dir.url else output_dir.path) - if obj.scheme == "s3": - s3 = S3Client() + is_remote = obj.scheme in _SUPPORTED_CLOUD_PROVIDERS while True: data: Optional[Union[str, Tuple[str, str]]] = upload_queue.get() @@ -223,7 +226,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_ if not local_filepath.startswith(cache_dir): local_filepath = os.path.join(cache_dir, local_filepath) - if obj.scheme == "s3": + if is_remote: try: output_filepath = str(obj.path).lstrip("/") @@ -235,12 +238,8 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_ output_filepath = os.path.join(output_filepath, local_filepath.replace(tmpdir, "")[1:]) output_filepath = remove_uuid_from_filename(output_filepath) # remove unique id from checkpoints - - s3.client.upload_file( - local_filepath, - obj.netloc, - output_filepath, - ) + remote_filepath = str(obj.scheme) + "://" + str(obj.netloc) + "/" + output_filepath + upload_file_or_directory(local_filepath, remote_filepath) except Exception as e: print(e) @@ -842,10 +841,11 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra else: local_filepath = os.path.join(cache_dir, _INDEX_FILENAME) - if obj.scheme == "s3": - s3 = S3Client() - s3.client.upload_file( - local_filepath, obj.netloc, os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath)) + if obj.scheme in _SUPPORTED_CLOUD_PROVIDERS: + remote_filepath = str(obj.scheme) + "://" + str(obj.netloc) + "/" + upload_file_or_directory( + local_filepath, + remote_filepath + os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath)), ) elif output_dir.path and os.path.isdir(output_dir.path): shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath))) @@ -863,11 +863,12 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra assert output_dir_path remote_filepath = os.path.join(output_dir_path, 
f"{node_rank}-{_INDEX_FILENAME}") node_index_filepath = os.path.join(cache_dir, os.path.basename(remote_filepath)) - if obj.scheme == "s3": - obj = parse.urlparse(remote_filepath) - _wait_for_file_to_exist(s3, obj) - with open(node_index_filepath, "wb") as f: - s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f) + if obj.scheme in _SUPPORTED_CLOUD_PROVIDERS: + _wait_for_file_to_exist(remote_filepath) + download_file_or_directory( + remote_filepath, + node_index_filepath, + ) elif output_dir.path and os.path.isdir(output_dir.path): shutil.copyfile(remote_filepath, node_index_filepath) @@ -1239,21 +1240,12 @@ def _cleanup_checkpoints(self) -> None: obj = parse.urlparse(self.output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. Found {self.output_dir.path}.") - - s3 = boto3.client("s3") - - prefix = obj.path.lstrip("/").rstrip("/") + "/" - - # Delete all the files (including the index file in overwrite mode) - bucket_name = obj.netloc - s3 = boto3.resource("s3") - - checkpoint_prefix = os.path.join(prefix, ".checkpoints") - - for obj in s3.Bucket(bucket_name).objects.filter(Prefix=checkpoint_prefix): - s3.Object(bucket_name, obj.key).delete() + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError( + f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {self.output_dir.path}." + ) + with suppress(FileNotFoundError): + remove_file_or_directory(os.path.join(self.output_dir.url, ".checkpoints")) def _save_current_config(self, workers_user_items: List[List[Any]]) -> None: if not self.use_checkpoint: @@ -1279,24 +1271,18 @@ def _save_current_config(self, workers_user_items: List[List[Any]]) -> None: obj = parse.urlparse(self.output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. Found {self.output_dir.path}.") - - # TODO: Add support for all cloud providers - - s3 = S3Client() - - prefix = obj.path.lstrip("/").rstrip("/") + "/" + ".checkpoints/" + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError( + f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {self.output_dir.path}." + ) # write config.json file to temp directory and upload it to s3 with tempfile.TemporaryDirectory() as temp_dir: temp_file_name = os.path.join(temp_dir, "config.json") with open(temp_file_name, "w") as f: json.dump(config, f) - s3.client.upload_file( - temp_file_name, - obj.netloc, - os.path.join(prefix, "config.json"), + upload_file_or_directory( + temp_file_name, os.path.join(self.output_dir.url, ".checkpoints", "config.json") ) except Exception as e: print(e) @@ -1347,26 +1333,23 @@ def _load_checkpoint_config(self, workers_user_items: List[List[Any]]) -> None: obj = parse.urlparse(self.output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. Found {self.output_dir.path}.") - - # TODO: Add support for all cloud providers - - prefix = obj.path.lstrip("/").rstrip("/") + "/" + ".checkpoints/" - - # Delete all the files (including the index file in overwrite mode) - bucket_name = obj.netloc + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError( + f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {self.output_dir.path}." 
+ ) # download all the checkpoint files in tempdir and read them with tempfile.TemporaryDirectory() as temp_dir: - saved_file_dir = download_directory_from_S3(bucket_name, prefix, temp_dir) - - if not os.path.exists(os.path.join(saved_file_dir, "config.json")): + try: + download_file_or_directory(os.path.join(self.output_dir.url, ".checkpoints/"), temp_dir) + except FileNotFoundError: + return + if not os.path.exists(os.path.join(temp_dir, "config.json")): # if the config.json file doesn't exist, we don't have any checkpoint saved return # read the config.json file - with open(os.path.join(saved_file_dir, "config.json")) as f: + with open(os.path.join(temp_dir, "config.json")) as f: config = json.load(f) if config["num_workers"] != self.num_workers: @@ -1380,11 +1363,11 @@ def _load_checkpoint_config(self, workers_user_items: List[List[Any]]) -> None: checkpoint_file_names = [f"checkpoint-{worker_idx}.json" for worker_idx in range(self.num_workers)] for i, checkpoint_file_name in enumerate(checkpoint_file_names): - if not os.path.exists(os.path.join(saved_file_dir, checkpoint_file_name)): + if not os.path.exists(os.path.join(temp_dir, checkpoint_file_name)): # if the checkpoint file doesn't exist, we don't have any checkpoint saved for this worker continue - with open(os.path.join(saved_file_dir, checkpoint_file_name)) as f: + with open(os.path.join(temp_dir, checkpoint_file_name)) as f: checkpoint = json.load(f) self.checkpoint_chunks_info[i] = checkpoint["chunks"] diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index 7ce32e58..34948fc9 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -27,7 +27,7 @@ import torch -from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO +from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _SUPPORTED_CLOUD_PROVIDERS from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe from litdata.processing.readers import BaseReader from litdata.processing.utilities import ( @@ -36,8 +36,8 @@ optimize_dns_context, read_index_file_content, ) -from litdata.streaming.client import S3Client from litdata.streaming.dataloader import StreamingDataLoader +from litdata.streaming.downloader import copy_file_or_directory, upload_file_or_directory from litdata.streaming.item_loader import BaseItemLoader from litdata.streaming.resolver import ( Dir, @@ -53,7 +53,7 @@ def _is_remote_file(path: str) -> bool: obj = parse.urlparse(path) - return obj.scheme in ["s3", "gcs"] + return obj.scheme in _SUPPORTED_CLOUD_PROVIDERS def _get_indexed_paths(data: Any) -> Dict[int, str]: @@ -595,15 +595,10 @@ def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None: shutil.copyfile(input_filepath, output_filepath) elif output_dir.url and copy_info.input_dir.url: - input_obj = parse.urlparse(os.path.join(copy_info.input_dir.url, copy_info.old_filename)) - output_obj = parse.urlparse(os.path.join(output_dir.url, copy_info.new_filename)) - - s3 = S3Client() - s3.client.copy( - {"Bucket": input_obj.netloc, "Key": input_obj.path.lstrip("/")}, - output_obj.netloc, - output_obj.path.lstrip("/"), - ) + input_obj = os.path.join(copy_info.input_dir.url, copy_info.old_filename) + output_obj = os.path.join(output_dir.url, copy_info.new_filename) + + copy_file_or_directory(input_obj, output_obj) else: raise NotImplementedError @@ -619,11 +614,4 @@ def _save_index(index_json: Dict, output_dir: Dir) -> None: f.flush() - obj = parse.urlparse(os.path.join(output_dir.url, 
_INDEX_FILENAME)) - - s3 = S3Client() - s3.client.upload_file( - f.name, - obj.netloc, - obj.path.lstrip("/"), - ) + upload_file_or_directory(f.name, os.path.join(output_dir.url, _INDEX_FILENAME)) diff --git a/src/litdata/processing/utilities.py b/src/litdata/processing/utilities.py index 939c9e66..4c61f48d 100644 --- a/src/litdata/processing/utilities.py +++ b/src/litdata/processing/utilities.py @@ -21,11 +21,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from urllib import parse -import boto3 -import botocore - -from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO +from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _SUPPORTED_CLOUD_PROVIDERS from litdata.streaming.cache import Dir +from litdata.streaming.downloader import download_file_or_directory def _create_dataset( @@ -201,27 +199,24 @@ def read_index_file_content(output_dir: Dir) -> Optional[Dict[str, Any]]: # download the index file from s3, and read it obj = parse.urlparse(output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. Found {output_dir.path}.") - - # TODO: Add support for all cloud providers - s3 = boto3.client("s3") - - prefix = obj.path.lstrip("/").rstrip("/") + "/" + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError( + f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {output_dir.path}." + ) # Check the index file exists try: # Create a temporary file with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file: temp_file_name = temp_file.name - s3.download_file(obj.netloc, os.path.join(prefix, _INDEX_FILENAME), temp_file_name) + download_file_or_directory(os.path.join(output_dir.url, _INDEX_FILENAME), temp_file_name) # Read data from the temporary file with open(temp_file_name) as temp_file: data = json.load(temp_file) # Delete the temporary file os.remove(temp_file_name) return data - except botocore.exceptions.ClientError: + except Exception as _e: return None @@ -257,21 +252,3 @@ def remove_uuid_from_filename(filepath: str) -> str: # uuid is of 32 characters, '.json' is 5 characters and '-' is 1 character return filepath[:-38] + ".json" - - -def download_directory_from_S3(bucket_name: str, remote_directory_name: str, local_directory_name: str) -> str: - s3_resource = boto3.resource("s3") - bucket = s3_resource.Bucket(bucket_name) - - saved_file_dir = "." 
- - for obj in bucket.objects.filter(Prefix=remote_directory_name): - local_filename = os.path.join(local_directory_name, obj.key) - - if not os.path.exists(os.path.dirname(local_filename)): - os.makedirs(os.path.dirname(local_filename)) - with open(local_filename, "wb") as f: - s3_resource.meta.client.download_fileobj(bucket_name, obj.key, f) - saved_file_dir = os.path.dirname(local_filename) - - return saved_file_dir diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 8c4d8bc9..56ee7350 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -15,7 +15,7 @@ import shutil import subprocess from abc import ABC -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from urllib import parse import fsspec @@ -34,6 +34,10 @@ def __init__( chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}, ): + print("-" * 80) + print(f"{cloud_provider=}") + print("-" * 80) + self._remote_dir = remote_dir self._cache_dir = cache_dir self._chunks = chunks @@ -185,13 +189,13 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None: super().download_file(remote_filepath, local_filepath) -_DOWNLOADERS = { - "s3://": S3Downloader, - "gs://": GCPDownloader, - "azure://": AzureDownloader, - "local:": LocalDownloaderWithCache, - "": LocalDownloader, -} +# _DOWNLOADERS = { +# "s3://": S3Downloader, +# "gs://": GCPDownloader, +# "azure://": AzureDownloader, +# "local:": LocalDownloaderWithCache, +# "": LocalDownloader, +# } _DOWNLOADERS = { @@ -213,6 +217,7 @@ def __init__( storage_options: Dict | None = {}, ): remote_dir = remote_dir.replace("local:", "") + self.is_local = False super().__init__(cloud_provider, remote_dir, cache_dir, chunks, storage_options) def download_file(self, remote_filepath: str, local_filepath: str) -> None: @@ -226,6 +231,68 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None: pass +def does_file_exist( + remote_filepath: str, cloud_provider: Union[str, None] = None, storage_options: Optional[Dict] = {} +) -> bool: + if cloud_provider is None: + cloud_provider = get_cloud_provider(remote_filepath) + + fs = fsspec.filesystem(cloud_provider, **storage_options) + return fs.exists(remote_filepath) + + +def list_directory( + remote_directory: str, + detail: bool = False, + cloud_provider: Union[str, None] = None, + storage_options: Optional[Dict] = {}, +) -> List[str]: + """returns a list of filenames in a remote directory""" + if cloud_provider is None: + cloud_provider = get_cloud_provider(remote_directory) + + fs = fsspec.filesystem(cloud_provider, **storage_options) + return fs.ls(remote_directory, detail=detail) # just return the filenames + + +def download_file_or_directory(remote_filepath: str, local_filepath: str, storage_options: Optional[Dict] = {}) -> None: + """upload a file to the remote cloud storage""" + fs_cloud_provider = get_cloud_provider(remote_filepath) + fs = fsspec.filesystem(fs_cloud_provider, **storage_options) + fs.get(remote_filepath, local_filepath, recursive=True) + + +def upload_file_or_directory(local_filepath: str, remote_filepath: str, storage_options: Optional[Dict] = {}) -> None: + """upload a file to the remote cloud storage""" + print(f"{local_filepath=}; {remote_filepath=}") + fs_cloud_provider = get_cloud_provider(remote_filepath) + fs = fsspec.filesystem(fs_cloud_provider, **storage_options) + fs.put(local_filepath, remote_filepath, recursive=True) + + +def 
copy_file_or_directory( + remote_filepath_src: str, remote_filepath_tg: str, storage_options: Optional[Dict] = {} +) -> None: + """copy a file from src to target on the remote cloud storage""" + fs_cloud_provider = get_cloud_provider(remote_filepath_src) + fs = fsspec.filesystem(fs_cloud_provider, **storage_options) + fs.copy(remote_filepath_src, remote_filepath_tg, recursive=True) + + +def remove_file_or_directory(remote_filepath: str, storage_options: Optional[Dict] = {}) -> None: + """remove a file from the remote cloud storage""" + fs_cloud_provider = get_cloud_provider(remote_filepath) + fs = fsspec.filesystem(fs_cloud_provider, **storage_options) + fs.rm(remote_filepath, recursive=True) + + +def get_cloud_provider(remote_filepath: str) -> str: + for k, fs_cloud_provider in _DOWNLOADERS.items(): + if str(remote_filepath).startswith(k): + return fs_cloud_provider + raise ValueError(f"The provided `remote_filepath` {remote_filepath} doesn't have a downloader associated.") + + def get_downloader_cls( remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} ) -> Downloader: diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index 24a8504e..863a5b26 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -23,10 +23,12 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union from urllib import parse -import boto3 -import botocore - -from litdata.constants import _LIGHTNING_SDK_AVAILABLE +from litdata.constants import _LIGHTNING_SDK_AVAILABLE, _SUPPORTED_CLOUD_PROVIDERS +from litdata.streaming.downloader import ( + does_file_exist, + list_directory, + remove_file_or_directory, +) if TYPE_CHECKING: from lightning_sdk import Machine @@ -52,8 +54,8 @@ def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir: assert isinstance(dir_path, str) - cloud_prefixes = ("s3://", "gs://", "azure://") - if dir_path.startswith(cloud_prefixes): + cloud_prefixes = _SUPPORTED_CLOUD_PROVIDERS + if any(dir_path.startswith(cloud_prefix) for cloud_prefix in cloud_prefixes): return Dir(path=None, url=dir_path) if dir_path.startswith("local:"): @@ -218,20 +220,16 @@ def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool obj = parse.urlparse(output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. Found {output_dir.url}.") - - s3 = boto3.client("s3") + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError(f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {output_dir.url}.") - objects = s3.list_objects_v2( - Bucket=obj.netloc, - Delimiter="/", - Prefix=obj.path.lstrip("/").rstrip("/") + "/", - ) + try: + object_list = list_directory(output_dir.url) + except FileNotFoundError: + return # We aren't alloweing to add more data - # TODO: Add support for `append` and `overwrite`. - if objects["KeyCount"] > 0: + if len(object_list) > 0: raise RuntimeError( f"The provided output_dir `{output_dir.path}` already contains data and datasets are meant to be immutable." " HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?" @@ -283,29 +281,19 @@ def _assert_dir_has_index_file( obj = parse.urlparse(output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. 
Found {output_dir.url}.") + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError(f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {output_dir.url}.") - s3 = boto3.client("s3") - - prefix = obj.path.lstrip("/").rstrip("/") + "/" - - objects = s3.list_objects_v2( - Bucket=obj.netloc, - Delimiter="/", - Prefix=prefix, - ) + objects_list = [] + with suppress(FileNotFoundError): + objects_list = list_directory(output_dir.url) # No files are found in this folder - if objects["KeyCount"] == 0: + if len(objects_list) == 0: return # Check the index file exists - try: - s3.head_object(Bucket=obj.netloc, Key=os.path.join(prefix, "index.json")) - has_index_file = True - except botocore.exceptions.ClientError: - has_index_file = False + has_index_file = does_file_exist(os.path.join(output_dir.url, "index.json")) if has_index_file and mode is None: raise RuntimeError( @@ -314,13 +302,8 @@ def _assert_dir_has_index_file( " HINT: If you want to append/overwrite to the existing dataset, use `mode='append | overwrite'`." ) - # Delete all the files (including the index file in overwrite mode) - bucket_name = obj.netloc - s3 = boto3.resource("s3") - if mode == "overwrite" or (mode is None and not use_checkpoint): - for obj in s3.Bucket(bucket_name).objects.filter(Prefix=prefix): - s3.Object(bucket_name, obj.key).delete() + remove_file_or_directory(output_dir.url) def _get_lightning_cloud_url() -> str: From 0c761b1aad137405af316a34dd9b5df647d678d5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Sep 2024 16:32:38 +0000 Subject: [PATCH 07/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/processing/data_processor.py | 6 +++++- src/litdata/streaming/downloader.py | 10 +++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 6b3d1a5c..449948f1 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -102,7 +102,11 @@ def _get_cache_data_dir(name: Optional[str] = None) -> str: def _wait_for_file_to_exist(remote_filepath: str, sleep_time: int = 2, wait_for_count: int = 5) -> Any: - """This function check if a file exists on the remote storage. If not, it waits for a while and tries again.""" + """This function check if a file exists on the remote storage. + + If not, it waits for a while and tries again. 
+ + """ cloud_provider = get_cloud_provider(remote_filepath) while True: try: diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 56ee7350..9b1f95ec 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -247,7 +247,7 @@ def list_directory( cloud_provider: Union[str, None] = None, storage_options: Optional[Dict] = {}, ) -> List[str]: - """returns a list of filenames in a remote directory""" + """Returns a list of filenames in a remote directory.""" if cloud_provider is None: cloud_provider = get_cloud_provider(remote_directory) @@ -256,14 +256,14 @@ def list_directory( def download_file_or_directory(remote_filepath: str, local_filepath: str, storage_options: Optional[Dict] = {}) -> None: - """upload a file to the remote cloud storage""" + """Upload a file to the remote cloud storage.""" fs_cloud_provider = get_cloud_provider(remote_filepath) fs = fsspec.filesystem(fs_cloud_provider, **storage_options) fs.get(remote_filepath, local_filepath, recursive=True) def upload_file_or_directory(local_filepath: str, remote_filepath: str, storage_options: Optional[Dict] = {}) -> None: - """upload a file to the remote cloud storage""" + """Upload a file to the remote cloud storage.""" print(f"{local_filepath=}; {remote_filepath=}") fs_cloud_provider = get_cloud_provider(remote_filepath) fs = fsspec.filesystem(fs_cloud_provider, **storage_options) @@ -273,14 +273,14 @@ def upload_file_or_directory(local_filepath: str, remote_filepath: str, storage_ def copy_file_or_directory( remote_filepath_src: str, remote_filepath_tg: str, storage_options: Optional[Dict] = {} ) -> None: - """copy a file from src to target on the remote cloud storage""" + """Copy a file from src to target on the remote cloud storage.""" fs_cloud_provider = get_cloud_provider(remote_filepath_src) fs = fsspec.filesystem(fs_cloud_provider, **storage_options) fs.copy(remote_filepath_src, remote_filepath_tg, recursive=True) def remove_file_or_directory(remote_filepath: str, storage_options: Optional[Dict] = {}) -> None: - """remove a file from the remote cloud storage""" + """Remove a file from the remote cloud storage.""" fs_cloud_provider = get_cloud_provider(remote_filepath) fs = fsspec.filesystem(fs_cloud_provider, **storage_options) fs.rm(remote_filepath, recursive=True) From 2377983ed6bf8047ff4b80253ae1491430fd4f34 Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 4 Sep 2024 09:17:30 +0530 Subject: [PATCH 08/43] fixed mypy errors and lock files when uploading/downloading --- src/litdata/streaming/downloader.py | 33 +++++++++++++++++++---------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 9b1f95ec..f22a927d 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -57,7 +57,7 @@ class S3Downloader(Downloader): def __init__( self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} ): - super().__init__(remote_dir, cache_dir, chunks, storage_options) + super().__init__("s3",remote_dir, cache_dir, chunks, storage_options) self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 if not self._s5cmd_available: @@ -109,7 +109,7 @@ def __init__( if not _GOOGLE_STORAGE_AVAILABLE: raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE)) - super().__init__(remote_dir, cache_dir, chunks, storage_options) + super().__init__("gs",remote_dir, cache_dir, chunks, storage_options) 
def download_file(self, remote_filepath: str, local_filepath: str) -> None: from google.cloud import storage @@ -146,7 +146,7 @@ def __init__( if not _AZURE_STORAGE_AVAILABLE: raise ModuleNotFoundError(str(_AZURE_STORAGE_AVAILABLE)) - super().__init__(remote_dir, cache_dir, chunks, storage_options) + super().__init__("abfs",remote_dir, cache_dir, chunks, storage_options) def download_file(self, remote_filepath: str, local_filepath: str) -> None: from azure.storage.blob import BlobServiceClient @@ -256,18 +256,29 @@ def list_directory( def download_file_or_directory(remote_filepath: str, local_filepath: str, storage_options: Optional[Dict] = {}) -> None: - """Upload a file to the remote cloud storage.""" - fs_cloud_provider = get_cloud_provider(remote_filepath) - fs = fsspec.filesystem(fs_cloud_provider, **storage_options) - fs.get(remote_filepath, local_filepath, recursive=True) + """download a file from the remote cloud storage.""" + try: + with FileLock(local_filepath + ".lock", timeout=3): + fs_cloud_provider = get_cloud_provider(remote_filepath) + fs = fsspec.filesystem(fs_cloud_provider, **storage_options) + fs.get(remote_filepath, local_filepath, recursive=True) + except Timeout: + # another process is responsible to download that file, continue + pass + def upload_file_or_directory(local_filepath: str, remote_filepath: str, storage_options: Optional[Dict] = {}) -> None: """Upload a file to the remote cloud storage.""" - print(f"{local_filepath=}; {remote_filepath=}") - fs_cloud_provider = get_cloud_provider(remote_filepath) - fs = fsspec.filesystem(fs_cloud_provider, **storage_options) - fs.put(local_filepath, remote_filepath, recursive=True) + try: + with FileLock(local_filepath + ".lock", timeout=3): + fs_cloud_provider = get_cloud_provider(remote_filepath) + fs = fsspec.filesystem(fs_cloud_provider, **storage_options) + fs.put(local_filepath, remote_filepath, recursive=True) + except Timeout: + # another process is responsible to upload that file, continue + pass + def copy_file_or_directory( From ffbf51d9aa9623967fa37b5563e0665e0462ac0a Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 4 Sep 2024 14:22:48 +0530 Subject: [PATCH 09/43] update --- src/litdata/streaming/resolver.py | 11 ++++------- tests/conftest.py | 1 + tests/processing/test_functions.py | 2 ++ tests/streaming/test_downloader.py | 2 +- tests/streaming/test_resolver.py | 1 + 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index 863a5b26..b61b2ddd 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -90,14 +90,11 @@ def _match_studio(target_id: Optional[str], target_name: Optional[str], cloudspa if target_id is not None and cloudspace.id == target_id: return True - if ( + return bool( cloudspace.display_name is not None and target_name is not None and cloudspace.display_name.lower() == target_name.lower() - ): - return True - - return False + ) def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Optional[str]) -> Dir: @@ -229,7 +226,7 @@ def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool return # We aren't alloweing to add more data - if len(object_list) > 0: + if object_list is not None and len(object_list) > 0: raise RuntimeError( f"The provided output_dir `{output_dir.path}` already contains data and datasets are meant to be immutable." " HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?" 
@@ -289,7 +286,7 @@ def _assert_dir_has_index_file( objects_list = list_directory(output_dir.url) # No files are found in this folder - if len(objects_list) == 0: + if objects_list is None or len(objects_list) == 0: return # Check the index file exists diff --git a/tests/conftest.py b/tests/conftest.py index cf5fc0a0..972b8109 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ import pytest import torch.distributed + from litdata.streaming.reader import PrepareChunksThread diff --git a/tests/processing/test_functions.py b/tests/processing/test_functions.py index 80eec0ba..f8a0ba07 100644 --- a/tests/processing/test_functions.py +++ b/tests/processing/test_functions.py @@ -1,4 +1,5 @@ import os +import shutil import sys from unittest import mock @@ -10,6 +11,7 @@ from litdata.streaming.cache import Cache from litdata.utilities.encryption import FernetEncryption, RSAEncryption from PIL import Image +import fsspec @pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.") diff --git a/tests/streaming/test_downloader.py b/tests/streaming/test_downloader.py index 396dc5d4..43198b4f 100644 --- a/tests/streaming/test_downloader.py +++ b/tests/streaming/test_downloader.py @@ -78,7 +78,7 @@ def test_download_with_cache(tmpdir, monkeypatch): f.write("hello") try: - local_downloader = LocalDownloaderWithCache(tmpdir, tmpdir, []) + local_downloader = LocalDownloaderWithCache("file", tmpdir, tmpdir, []) shutil_mock = MagicMock() monkeypatch.setattr(shutil, "copy", shutil_mock) local_downloader.download_file("local:a.txt", os.path.join(tmpdir, "a.txt")) diff --git a/tests/streaming/test_resolver.py b/tests/streaming/test_resolver.py index 2c962454..48bf8e4a 100644 --- a/tests/streaming/test_resolver.py +++ b/tests/streaming/test_resolver.py @@ -15,6 +15,7 @@ V1ListClustersResponse, V1ListDataConnectionsResponse, ) + from litdata.streaming import resolver From de8b83bdd349e02f994117ccf38bb064a3f828db Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 4 Sep 2024 16:30:40 +0530 Subject: [PATCH 10/43] fixed test `test_try_create_cache_dir` --- src/litdata/utilities/dataset_utilities.py | 4 ++-- tests/streaming/test_reader.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index a7ff6840..fe500b0c 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -95,7 +95,7 @@ def _should_replace_path(path: Optional[str]) -> bool: return path.startswith("/teamspace/datasets/") or path.startswith("/teamspace/s3_connections/") -def _read_updated_at(input_dir: Optional[Dir], storage_options: Optional[Dict]) -> str: +def _read_updated_at(input_dir: Optional[Dir], storage_options: Optional[Dict]={}) -> str: """Read last updated timestamp from index.json file.""" last_updation_timestamp = "0" index_json_content = None @@ -134,7 +134,7 @@ def _clear_cache_dir_if_updated(input_dir_hash_filepath: str, updated_at_hash: s shutil.rmtree(input_dir_hash_filepath) -def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Dict]) -> Optional[str]: +def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Dict]={}) -> Optional[str]: resolved_input_dir = _resolve_dir(input_dir) updated_at = _read_updated_at(resolved_input_dir, storage_options) diff --git a/tests/streaming/test_reader.py b/tests/streaming/test_reader.py index 9df6aa51..bc028eb8 100644 --- a/tests/streaming/test_reader.py 
+++ b/tests/streaming/test_reader.py @@ -3,6 +3,7 @@ from time import sleep import numpy as np + from litdata.streaming import reader from litdata.streaming.cache import Cache from litdata.streaming.config import ChunkedIndex From e7123272a2123ead5b7c44becc85a1a23fa76c8e Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 4 Sep 2024 16:49:32 +0530 Subject: [PATCH 11/43] fixed test: `test_reader_chunk_removal` --- src/litdata/streaming/downloader.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index f22a927d..40b483bc 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -226,6 +226,9 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None: try: with FileLock(local_filepath + ".lock", timeout=3): self.fs.get(remote_filepath, local_filepath, recursive=True) + # remove the lock file + if os.path.exists(local_filepath + ".lock"): + os.remove(local_filepath + ".lock") except Timeout: # another process is responsible to download that file, continue pass From e118ba9c78dade10937ab3115cddf69c50275055 Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 4 Sep 2024 18:16:26 +0530 Subject: [PATCH 12/43] all tests passed --- src/litdata/streaming/resolver.py | 2 +- tests/processing/test_data_processor.py | 32 +++++------- tests/streaming/test_downloader.py | 65 +++---------------------- tests/streaming/test_resolver.py | 56 ++++++++++----------- 4 files changed, 48 insertions(+), 107 deletions(-) diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index b61b2ddd..a351eb25 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -224,7 +224,7 @@ def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool object_list = list_directory(output_dir.url) except FileNotFoundError: return - + print(f"{object_list=}") # We aren't alloweing to add more data if object_list is not None and len(object_list) > 0: raise RuntimeError( diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py index 543b6909..6b409e0f 100644 --- a/tests/processing/test_data_processor.py +++ b/tests/processing/test_data_processor.py @@ -10,6 +10,7 @@ import pytest import torch from lightning_utilities.core.imports import RequirementCache + from litdata.constants import _TORCH_AUDIO_AVAILABLE, _ZSTD_AVAILABLE from litdata.processing import data_processor as data_processor_module from litdata.processing import functions @@ -109,8 +110,6 @@ def fn(*_, **__): remove_queue = mock.MagicMock() - s3_client = mock.MagicMock() - called = False def copy_file(local_filepath, *args): @@ -120,9 +119,7 @@ def copy_file(local_filepath, *args): copyfile(local_filepath, os.path.join(remote_output_dir, os.path.basename(local_filepath))) - s3_client.client.upload_file = copy_file - - monkeypatch.setattr(data_processor_module, "S3Client", mock.MagicMock(return_value=s3_client)) + monkeypatch.setattr(data_processor_module, "upload_file_or_directory", copy_file) assert os.listdir(remote_output_dir) == [] @@ -217,32 +214,28 @@ def test_wait_for_disk_usage_higher_than_threshold(): @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows") -def test_wait_for_file_to_exist(): - import botocore - - s3 = mock.MagicMock() - obj = mock.MagicMock() +def test_wait_for_file_to_exist(monkeypatch): raise_error = [True, True, False] def fn(*_, **__): value = raise_error.pop(0) if value: 
- raise botocore.exceptions.ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") + raise Exception({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") # some exception return - s3.client.head_object = fn + monkeypatch.setattr(data_processor_module, "does_file_exist", fn) - _wait_for_file_to_exist(s3, obj, sleep_time=0.01) + _wait_for_file_to_exist("s3://some-dummy-bucket/some-dummy-key", sleep_time=0.01) assert len(raise_error) == 0 def fn(*_, **__): raise ValueError("HERE") - s3.client.head_object = fn + monkeypatch.setattr(data_processor_module, "does_file_exist", fn) with pytest.raises(ValueError, match="HERE"): - _wait_for_file_to_exist(s3, obj, sleep_time=0.01) + _wait_for_file_to_exist("s3://some-dummy-bucket/some-dummy-key", sleep_time=0.01) def test_cache_dir_cleanup(tmpdir, monkeypatch): @@ -1025,11 +1018,10 @@ def test_data_processing_map_non_absolute_path(monkeypatch, tmpdir): @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows") def test_map_error_when_not_empty(monkeypatch): - boto3 = mock.MagicMock() - client_s3_mock = mock.MagicMock() - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []} - boto3.client.return_value = client_s3_mock - monkeypatch.setattr(resolver, "boto3", boto3) + def mock_list_directory(*args, **kwargs): + return ["a.txt", "b.txt"] + + monkeypatch.setattr(resolver, "list_directory", mock_list_directory) with pytest.raises(RuntimeError, match="data and datasets are meant to be immutable"): map( diff --git a/tests/streaming/test_downloader.py b/tests/streaming/test_downloader.py index 43198b4f..63a1dccf 100644 --- a/tests/streaming/test_downloader.py +++ b/tests/streaming/test_downloader.py @@ -11,65 +11,14 @@ subprocess, ) +# def test_s3_downloader_fast(tmpdir, monkeypatch): +# monkeypatch.setattr(os, "system", MagicMock(return_value=0)) +# popen_mock = MagicMock() +# monkeypatch.setattr(subprocess, "Popen", MagicMock(return_value=popen_mock)) +# downloader = S3Downloader(tmpdir, tmpdir, []) +# downloader.download_file("s3://random_bucket/a.txt", os.path.join(tmpdir, "a.txt")) +# popen_mock.wait.assert_called() -def test_s3_downloader_fast(tmpdir, monkeypatch): - monkeypatch.setattr(os, "system", MagicMock(return_value=0)) - popen_mock = MagicMock() - monkeypatch.setattr(subprocess, "Popen", MagicMock(return_value=popen_mock)) - downloader = S3Downloader(tmpdir, tmpdir, []) - downloader.download_file("s3://random_bucket/a.txt", os.path.join(tmpdir, "a.txt")) - popen_mock.wait.assert_called() - - -@mock.patch("litdata.streaming.downloader._GOOGLE_STORAGE_AVAILABLE", True) -def test_gcp_downloader(tmpdir, monkeypatch, google_mock): - # Create mock objects - mock_client = MagicMock() - mock_bucket = MagicMock() - mock_blob = MagicMock() - mock_blob.download_to_filename = MagicMock() - - # Patch the storage client to return the mock client - google_mock.cloud.storage.Client = MagicMock(return_value=mock_client) - - # Configure the mock client to return the mock bucket and blob - mock_client.bucket = MagicMock(return_value=mock_bucket) - mock_bucket.blob = MagicMock(return_value=mock_blob) - - # Initialize the downloader - storage_options = {"project": "DUMMY_PROJECT"} - downloader = GCPDownloader("gs://random_bucket", tmpdir, [], storage_options) - local_filepath = os.path.join(tmpdir, "a.txt") - downloader.download_file("gs://random_bucket/a.txt", local_filepath) - - # Assert that the correct methods were called - 
google_mock.cloud.storage.Client.assert_called_with(**storage_options) - mock_client.bucket.assert_called_with("random_bucket") - mock_bucket.blob.assert_called_with("a.txt") - mock_blob.download_to_filename.assert_called_with(local_filepath) - - -@mock.patch("litdata.streaming.downloader._AZURE_STORAGE_AVAILABLE", True) -def test_azure_downloader(tmpdir, monkeypatch, azure_mock): - mock_blob = MagicMock() - mock_blob_data = MagicMock() - mock_blob.download_blob.return_value = mock_blob_data - service_mock = MagicMock() - service_mock.get_blob_client.return_value = mock_blob - - azure_mock.storage.blob.BlobServiceClient = MagicMock(return_value=service_mock) - - # Initialize the downloader - storage_options = {"project": "DUMMY_PROJECT"} - downloader = AzureDownloader("azure://random_bucket", tmpdir, [], storage_options) - local_filepath = os.path.join(tmpdir, "a.txt") - downloader.download_file("azure://random_bucket/a.txt", local_filepath) - - # Assert that the correct methods were called - azure_mock.storage.blob.BlobServiceClient.assert_called_with(**storage_options) - service_mock.get_blob_client.assert_called_with(container="random_bucket", blob="a.txt") - mock_blob.download_blob.assert_called() - mock_blob_data.readinto.assert_called() def test_download_with_cache(tmpdir, monkeypatch): diff --git a/tests/streaming/test_resolver.py b/tests/streaming/test_resolver.py index 48bf8e4a..8a8dc7b8 100644 --- a/tests/streaming/test_resolver.py +++ b/tests/streaming/test_resolver.py @@ -301,52 +301,52 @@ def print_fn(msg, file=None): def test_assert_dir_is_empty(monkeypatch): - boto3 = mock.MagicMock() - client_s3_mock = mock.MagicMock() - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []} - boto3.client.return_value = client_s3_mock - resolver.boto3 = boto3 + def mock_list_directory(*args, **kwargs): + return ["a.txt", "b.txt"] + def mock_empty_list_directory(*args, **kwargs): + return [] + monkeypatch.setattr(resolver, "list_directory", mock_list_directory) with pytest.raises(RuntimeError, match="The provided output_dir"): resolver._assert_dir_is_empty(resolver.Dir(path="/teamspace/...", url="s3://")) - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 0, "Contents": []} - boto3.client.return_value = client_s3_mock - resolver.boto3 = boto3 + monkeypatch.setattr(resolver, "list_directory", mock_empty_list_directory) resolver._assert_dir_is_empty(resolver.Dir(path="/teamspace/...", url="s3://")) def test_assert_dir_has_index_file(monkeypatch): - boto3 = mock.MagicMock() - client_s3_mock = mock.MagicMock() - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []} - boto3.client.return_value = client_s3_mock - resolver.boto3 = boto3 + def mock_list_directory_0(*args, **kwargs): + return [] - with pytest.raises(RuntimeError, match="The provided output_dir"): - resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) + def mock_list_directory_1(*args, **kwargs): + return ['a.txt', 'b.txt'] - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 0, "Contents": []} - boto3.client.return_value = client_s3_mock - resolver.boto3 = boto3 + def mock_list_directory_2(*args, **kwargs): + return ["index.json"] - resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) + def mock_does_file_exist_1(*args, **kwargs): + raise Exception({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") # some exception - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": 
[]} + def mock_does_file_exist_2(*args, **kwargs): + return True - def head_object(*args, **kwargs): - import botocore + def mock_remove_file_or_directory(*args, **kwargs): + return - raise botocore.exceptions.ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") - - client_s3_mock.head_object = head_object - boto3.client.return_value = client_s3_mock - resolver.boto3 = boto3 + monkeypatch.setattr(resolver, "list_directory", mock_list_directory_0) + monkeypatch.setattr(resolver, "does_file_exist", mock_does_file_exist_1) + monkeypatch.setattr(resolver, "remove_file_or_directory", mock_remove_file_or_directory) resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) - boto3.resource.assert_called() + monkeypatch.setattr(resolver, "list_directory", mock_list_directory_2) + monkeypatch.setattr(resolver, "does_file_exist", mock_does_file_exist_2) + + with pytest.raises(RuntimeError, match="The provided output_dir"): + resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) + + resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://"), mode='overwrite') def test_resolve_dir_absolute(tmp_path, monkeypatch): From d3450dc53f6e5619bd772db2d12005da5299395b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Sep 2024 12:47:58 +0000 Subject: [PATCH 13/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/downloader.py | 10 ++++------ src/litdata/utilities/dataset_utilities.py | 4 ++-- tests/conftest.py | 1 - tests/processing/test_data_processor.py | 3 +-- tests/processing/test_functions.py | 2 -- tests/streaming/test_downloader.py | 6 ------ tests/streaming/test_reader.py | 1 - tests/streaming/test_resolver.py | 9 +++++---- 8 files changed, 12 insertions(+), 24 deletions(-) diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 40b483bc..82265faa 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -57,7 +57,7 @@ class S3Downloader(Downloader): def __init__( self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} ): - super().__init__("s3",remote_dir, cache_dir, chunks, storage_options) + super().__init__("s3", remote_dir, cache_dir, chunks, storage_options) self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 if not self._s5cmd_available: @@ -109,7 +109,7 @@ def __init__( if not _GOOGLE_STORAGE_AVAILABLE: raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE)) - super().__init__("gs",remote_dir, cache_dir, chunks, storage_options) + super().__init__("gs", remote_dir, cache_dir, chunks, storage_options) def download_file(self, remote_filepath: str, local_filepath: str) -> None: from google.cloud import storage @@ -146,7 +146,7 @@ def __init__( if not _AZURE_STORAGE_AVAILABLE: raise ModuleNotFoundError(str(_AZURE_STORAGE_AVAILABLE)) - super().__init__("abfs",remote_dir, cache_dir, chunks, storage_options) + super().__init__("abfs", remote_dir, cache_dir, chunks, storage_options) def download_file(self, remote_filepath: str, local_filepath: str) -> None: from azure.storage.blob import BlobServiceClient @@ -259,7 +259,7 @@ def list_directory( def download_file_or_directory(remote_filepath: str, local_filepath: str, storage_options: Optional[Dict] = {}) -> None: - """download a file from the remote cloud 
storage.""" + """Download a file from the remote cloud storage.""" try: with FileLock(local_filepath + ".lock", timeout=3): fs_cloud_provider = get_cloud_provider(remote_filepath) @@ -270,7 +270,6 @@ def download_file_or_directory(remote_filepath: str, local_filepath: str, storag pass - def upload_file_or_directory(local_filepath: str, remote_filepath: str, storage_options: Optional[Dict] = {}) -> None: """Upload a file to the remote cloud storage.""" try: @@ -283,7 +282,6 @@ def upload_file_or_directory(local_filepath: str, remote_filepath: str, storage_ pass - def copy_file_or_directory( remote_filepath_src: str, remote_filepath_tg: str, storage_options: Optional[Dict] = {} ) -> None: diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index fe500b0c..c494f1ec 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -95,7 +95,7 @@ def _should_replace_path(path: Optional[str]) -> bool: return path.startswith("/teamspace/datasets/") or path.startswith("/teamspace/s3_connections/") -def _read_updated_at(input_dir: Optional[Dir], storage_options: Optional[Dict]={}) -> str: +def _read_updated_at(input_dir: Optional[Dir], storage_options: Optional[Dict] = {}) -> str: """Read last updated timestamp from index.json file.""" last_updation_timestamp = "0" index_json_content = None @@ -134,7 +134,7 @@ def _clear_cache_dir_if_updated(input_dir_hash_filepath: str, updated_at_hash: s shutil.rmtree(input_dir_hash_filepath) -def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Dict]={}) -> Optional[str]: +def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Dict] = {}) -> Optional[str]: resolved_input_dir = _resolve_dir(input_dir) updated_at = _read_updated_at(resolved_input_dir, storage_options) diff --git a/tests/conftest.py b/tests/conftest.py index 972b8109..cf5fc0a0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ import pytest import torch.distributed - from litdata.streaming.reader import PrepareChunksThread diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py index 6b409e0f..8699f5e0 100644 --- a/tests/processing/test_data_processor.py +++ b/tests/processing/test_data_processor.py @@ -10,7 +10,6 @@ import pytest import torch from lightning_utilities.core.imports import RequirementCache - from litdata.constants import _TORCH_AUDIO_AVAILABLE, _ZSTD_AVAILABLE from litdata.processing import data_processor as data_processor_module from litdata.processing import functions @@ -220,7 +219,7 @@ def test_wait_for_file_to_exist(monkeypatch): def fn(*_, **__): value = raise_error.pop(0) if value: - raise Exception({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") # some exception + raise Exception({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") # some exception return monkeypatch.setattr(data_processor_module, "does_file_exist", fn) diff --git a/tests/processing/test_functions.py b/tests/processing/test_functions.py index f8a0ba07..80eec0ba 100644 --- a/tests/processing/test_functions.py +++ b/tests/processing/test_functions.py @@ -1,5 +1,4 @@ import os -import shutil import sys from unittest import mock @@ -11,7 +10,6 @@ from litdata.streaming.cache import Cache from litdata.utilities.encryption import FernetEncryption, RSAEncryption from PIL import Image -import fsspec @pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.") 
diff --git a/tests/streaming/test_downloader.py b/tests/streaming/test_downloader.py index 63a1dccf..ba1ed8aa 100644 --- a/tests/streaming/test_downloader.py +++ b/tests/streaming/test_downloader.py @@ -1,14 +1,9 @@ import os -from unittest import mock from unittest.mock import MagicMock from litdata.streaming.downloader import ( - AzureDownloader, - GCPDownloader, LocalDownloaderWithCache, - S3Downloader, shutil, - subprocess, ) # def test_s3_downloader_fast(tmpdir, monkeypatch): @@ -20,7 +15,6 @@ # popen_mock.wait.assert_called() - def test_download_with_cache(tmpdir, monkeypatch): # Create a file to download/cache with open("a.txt", "w") as f: diff --git a/tests/streaming/test_reader.py b/tests/streaming/test_reader.py index bc028eb8..9df6aa51 100644 --- a/tests/streaming/test_reader.py +++ b/tests/streaming/test_reader.py @@ -3,7 +3,6 @@ from time import sleep import numpy as np - from litdata.streaming import reader from litdata.streaming.cache import Cache from litdata.streaming.config import ChunkedIndex diff --git a/tests/streaming/test_resolver.py b/tests/streaming/test_resolver.py index 8a8dc7b8..630411f3 100644 --- a/tests/streaming/test_resolver.py +++ b/tests/streaming/test_resolver.py @@ -15,7 +15,6 @@ V1ListClustersResponse, V1ListDataConnectionsResponse, ) - from litdata.streaming import resolver @@ -303,8 +302,10 @@ def print_fn(msg, file=None): def test_assert_dir_is_empty(monkeypatch): def mock_list_directory(*args, **kwargs): return ["a.txt", "b.txt"] + def mock_empty_list_directory(*args, **kwargs): return [] + monkeypatch.setattr(resolver, "list_directory", mock_list_directory) with pytest.raises(RuntimeError, match="The provided output_dir"): @@ -320,13 +321,13 @@ def mock_list_directory_0(*args, **kwargs): return [] def mock_list_directory_1(*args, **kwargs): - return ['a.txt', 'b.txt'] + return ["a.txt", "b.txt"] def mock_list_directory_2(*args, **kwargs): return ["index.json"] def mock_does_file_exist_1(*args, **kwargs): - raise Exception({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") # some exception + raise Exception({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") # some exception def mock_does_file_exist_2(*args, **kwargs): return True @@ -346,7 +347,7 @@ def mock_remove_file_or_directory(*args, **kwargs): with pytest.raises(RuntimeError, match="The provided output_dir"): resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) - resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://"), mode='overwrite') + resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://"), mode="overwrite") def test_resolve_dir_absolute(tmp_path, monkeypatch): From ed0fff8dfb252b7ce7aef375441df7dce3c56a5a Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 4 Sep 2024 20:25:58 +0530 Subject: [PATCH 14/43] update --- requirements.txt | 3 --- requirements/extras.txt | 3 +++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 0fdc59c5..8c58bca4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,3 @@ numpy boto3 requests fsspec -fsspec[s3] # aws s3 -fsspec[gs] # google cloud storage -fsspec[abfs] # azure blob diff --git a/requirements/extras.txt b/requirements/extras.txt index 14bcb7de..69e5fafe 100644 --- a/requirements/extras.txt +++ b/requirements/extras.txt @@ -5,3 +5,6 @@ pyarrow tqdm lightning-cloud == 0.5.70 # Must be pinned to ensure compatibility google-cloud-storage +fsspec[s3] # aws s3 +fsspec[gs] # 
google cloud storage +fsspec[abfs] # azure blob From 08236e8022f5ec6373a11969eb197e3db2967dff Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 4 Sep 2024 20:30:02 +0530 Subject: [PATCH 15/43] update --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8c58bca4..6adea00f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,6 @@ torch lightning-utilities filelock numpy -boto3 +# boto3 requests fsspec From 12b049b6a0c081ee3e8a7ce1eee4094a901b92db Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 4 Sep 2024 20:42:14 +0530 Subject: [PATCH 16/43] boto3 stop bothering me --- src/litdata/streaming/client.py | 118 +++++++------- src/litdata/streaming/downloader.py | 244 ++++++++++++++-------------- tests/conftest.py | 1 + 3 files changed, 182 insertions(+), 181 deletions(-) diff --git a/src/litdata/streaming/client.py b/src/litdata/streaming/client.py index d24803c3..3e6b7ae4 100644 --- a/src/litdata/streaming/client.py +++ b/src/litdata/streaming/client.py @@ -1,70 +1,70 @@ -# Copyright The Lightning AI team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# # Copyright The Lightning AI team. +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. -import os -from time import time -from typing import Any, Dict, Optional +# import os +# from time import time +# from typing import Any, Dict, Optional -import boto3 -import botocore -from botocore.credentials import InstanceMetadataProvider -from botocore.utils import InstanceMetadataFetcher +# import boto3 +# import botocore +# from botocore.credentials import InstanceMetadataProvider +# from botocore.utils import InstanceMetadataFetcher -from litdata.constants import _IS_IN_STUDIO +# from litdata.constants import _IS_IN_STUDIO -class S3Client: - # TODO: Generalize to support more cloud providers. +# class S3Client: +# # TODO: Generalize to support more cloud providers. 
- def __init__(self, refetch_interval: int = 3300, storage_options: Optional[Dict] = {}) -> None: - self._refetch_interval = refetch_interval - self._last_time: Optional[float] = None - self._client: Optional[Any] = None - self._storage_options: dict = storage_options or {} +# def __init__(self, refetch_interval: int = 3300, storage_options: Optional[Dict] = {}) -> None: +# self._refetch_interval = refetch_interval +# self._last_time: Optional[float] = None +# self._client: Optional[Any] = None +# self._storage_options: dict = storage_options or {} - def _create_client(self) -> None: - has_shared_credentials_file = ( - os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials" - ) +# def _create_client(self) -> None: +# has_shared_credentials_file = ( +# os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials" +# ) - if has_shared_credentials_file or not _IS_IN_STUDIO: - self._client = boto3.client( - "s3", - **{ - "config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}), - **self._storage_options, - }, - ) - else: - provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5)) - credentials = provider.load() - self._client = boto3.client( - "s3", - aws_access_key_id=credentials.access_key, - aws_secret_access_key=credentials.secret_key, - aws_session_token=credentials.token, - config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}), - ) +# if has_shared_credentials_file or not _IS_IN_STUDIO: +# self._client = boto3.client( +# "s3", +# **{ +# "config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}), +# **self._storage_options, +# }, +# ) +# else: +# provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5)) +# credentials = provider.load() +# self._client = boto3.client( +# "s3", +# aws_access_key_id=credentials.access_key, +# aws_secret_access_key=credentials.secret_key, +# aws_session_token=credentials.token, +# config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}), +# ) - @property - def client(self) -> Any: - if self._client is None: - self._create_client() - self._last_time = time() +# @property +# def client(self) -> Any: +# if self._client is None: +# self._create_client() +# self._last_time = time() - # Re-generate credentials for EC2 - if self._last_time is None or (time() - self._last_time) > self._refetch_interval: - self._create_client() - self._last_time = time() +# # Re-generate credentials for EC2 +# if self._last_time is None or (time() - self._last_time) > self._refetch_interval: +# self._create_client() +# self._last_time = time() - return self._client +# return self._client diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 82265faa..7715a791 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -21,8 +21,8 @@ import fsspec from filelock import FileLock, Timeout -from litdata.constants import _AZURE_STORAGE_AVAILABLE, _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME -from litdata.streaming.client import S3Client +# from litdata.constants import _AZURE_STORAGE_AVAILABLE, _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME +# from litdata.streaming.client import S3Client class Downloader(ABC): @@ -53,125 +53,125 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None: pass -class 
S3Downloader(Downloader): - def __init__( - self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} - ): - super().__init__("s3", remote_dir, cache_dir, chunks, storage_options) - self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 - - if not self._s5cmd_available: - self._client = S3Client(storage_options=self._storage_options) - - def download_file(self, remote_filepath: str, local_filepath: str) -> None: - obj = parse.urlparse(remote_filepath) - - if obj.scheme != "s3": - raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}") - - if os.path.exists(local_filepath): - return - - try: - with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): - if self._s5cmd_available: - proc = subprocess.Popen( - f"s5cmd cp {remote_filepath} {local_filepath}", - shell=True, - stdout=subprocess.PIPE, - ) - proc.wait() - else: - from boto3.s3.transfer import TransferConfig - - extra_args: Dict[str, Any] = {} - - # try: - # with FileLock(local_filepath + ".lock", timeout=1): - if not os.path.exists(local_filepath): - # Issue: https://github.com/boto/boto3/issues/3113 - self._client.client.download_file( - obj.netloc, - obj.path.lstrip("/"), - local_filepath, - ExtraArgs=extra_args, - Config=TransferConfig(use_threads=False), - ) - except Timeout: - # another process is responsible to download that file, continue - pass - - -class GCPDownloader(Downloader): - def __init__( - self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} - ): - if not _GOOGLE_STORAGE_AVAILABLE: - raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE)) - - super().__init__("gs", remote_dir, cache_dir, chunks, storage_options) - - def download_file(self, remote_filepath: str, local_filepath: str) -> None: - from google.cloud import storage - - obj = parse.urlparse(remote_filepath) - - if obj.scheme != "gs": - raise ValueError(f"Expected obj.scheme to be `gs`, instead, got {obj.scheme} for remote={remote_filepath}") - - if os.path.exists(local_filepath): - return - - try: - with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): - bucket_name = obj.netloc - key = obj.path - # Remove the leading "/": - if key[0] == "/": - key = key[1:] - - client = storage.Client(**self._storage_options) - bucket = client.bucket(bucket_name) - blob = bucket.blob(key) - blob.download_to_filename(local_filepath) - except Timeout: - # another process is responsible to download that file, continue - pass - - -class AzureDownloader(Downloader): - def __init__( - self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} - ): - if not _AZURE_STORAGE_AVAILABLE: - raise ModuleNotFoundError(str(_AZURE_STORAGE_AVAILABLE)) - - super().__init__("abfs", remote_dir, cache_dir, chunks, storage_options) - - def download_file(self, remote_filepath: str, local_filepath: str) -> None: - from azure.storage.blob import BlobServiceClient - - obj = parse.urlparse(remote_filepath) - - if obj.scheme != "azure": - raise ValueError( - f"Expected obj.scheme to be `azure`, instead, got {obj.scheme} for remote={remote_filepath}" - ) - - if os.path.exists(local_filepath): - return - - try: - with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): - service = BlobServiceClient(**self._storage_options) - blob_client = 
service.get_blob_client(container=obj.netloc, blob=obj.path.lstrip("/")) - with open(local_filepath, "wb") as download_file: - blob_data = blob_client.download_blob() - blob_data.readinto(download_file) - - except Timeout: - # another process is responsible to download that file, continue - pass +# class S3Downloader(Downloader): +# def __init__( +# self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} +# ): +# super().__init__("s3", remote_dir, cache_dir, chunks, storage_options) +# self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 + +# if not self._s5cmd_available: +# self._client = S3Client(storage_options=self._storage_options) + +# def download_file(self, remote_filepath: str, local_filepath: str) -> None: +# obj = parse.urlparse(remote_filepath) + +# if obj.scheme != "s3": +# raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}") + +# if os.path.exists(local_filepath): +# return + +# try: +# with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): +# if self._s5cmd_available: +# proc = subprocess.Popen( +# f"s5cmd cp {remote_filepath} {local_filepath}", +# shell=True, +# stdout=subprocess.PIPE, +# ) +# proc.wait() +# else: +# from boto3.s3.transfer import TransferConfig + +# extra_args: Dict[str, Any] = {} + +# # try: +# # with FileLock(local_filepath + ".lock", timeout=1): +# if not os.path.exists(local_filepath): +# # Issue: https://github.com/boto/boto3/issues/3113 +# self._client.client.download_file( +# obj.netloc, +# obj.path.lstrip("/"), +# local_filepath, +# ExtraArgs=extra_args, +# Config=TransferConfig(use_threads=False), +# ) +# except Timeout: +# # another process is responsible to download that file, continue +# pass + + +# class GCPDownloader(Downloader): +# def __init__( +# self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} +# ): +# if not _GOOGLE_STORAGE_AVAILABLE: +# raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE)) + +# super().__init__("gs", remote_dir, cache_dir, chunks, storage_options) + +# def download_file(self, remote_filepath: str, local_filepath: str) -> None: +# from google.cloud import storage + +# obj = parse.urlparse(remote_filepath) + +# if obj.scheme != "gs": +# raise ValueError(f"Expected obj.scheme to be `gs`, instead, got {obj.scheme} for remote={remote_filepath}") + +# if os.path.exists(local_filepath): +# return + +# try: +# with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): +# bucket_name = obj.netloc +# key = obj.path +# # Remove the leading "/": +# if key[0] == "/": +# key = key[1:] + +# client = storage.Client(**self._storage_options) +# bucket = client.bucket(bucket_name) +# blob = bucket.blob(key) +# blob.download_to_filename(local_filepath) +# except Timeout: +# # another process is responsible to download that file, continue +# pass + + +# class AzureDownloader(Downloader): +# def __init__( +# self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} +# ): +# if not _AZURE_STORAGE_AVAILABLE: +# raise ModuleNotFoundError(str(_AZURE_STORAGE_AVAILABLE)) + +# super().__init__("abfs", remote_dir, cache_dir, chunks, storage_options) + +# def download_file(self, remote_filepath: str, local_filepath: str) -> None: +# from azure.storage.blob import BlobServiceClient + +# obj = parse.urlparse(remote_filepath) + +# if obj.scheme != 
"azure": +# raise ValueError( +# f"Expected obj.scheme to be `azure`, instead, got {obj.scheme} for remote={remote_filepath}" +# ) + +# if os.path.exists(local_filepath): +# return + +# try: +# with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): +# service = BlobServiceClient(**self._storage_options) +# blob_client = service.get_blob_client(container=obj.netloc, blob=obj.path.lstrip("/")) +# with open(local_filepath, "wb") as download_file: +# blob_data = blob_client.download_blob() +# blob_data.readinto(download_file) + +# except Timeout: +# # another process is responsible to download that file, continue +# pass class LocalDownloader(Downloader): @@ -214,7 +214,7 @@ def __init__( remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], - storage_options: Dict | None = {}, + storage_options: Optional[Dict] = {}, ): remote_dir = remote_dir.replace("local:", "") self.is_local = False diff --git a/tests/conftest.py b/tests/conftest.py index cf5fc0a0..972b8109 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ import pytest import torch.distributed + from litdata.streaming.reader import PrepareChunksThread From d560d91db071c5131cca1d8d73a45ee127b62eb0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Sep 2024 15:13:12 +0000 Subject: [PATCH 17/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/downloader.py | 2 -- tests/conftest.py | 1 - 2 files changed, 3 deletions(-) diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 7715a791..d96f1dda 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -13,10 +13,8 @@ import os import shutil -import subprocess from abc import ABC from typing import Any, Dict, List, Optional, Union -from urllib import parse import fsspec from filelock import FileLock, Timeout diff --git a/tests/conftest.py b/tests/conftest.py index 972b8109..cf5fc0a0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ import pytest import torch.distributed - from litdata.streaming.reader import PrepareChunksThread From 27644d33af54eb77a096d2bce0da28c9fc05eca8 Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 4 Sep 2024 20:47:45 +0530 Subject: [PATCH 18/43] update --- src/litdata/streaming/client.py | 4 ++-- src/litdata/streaming/downloader.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/litdata/streaming/client.py b/src/litdata/streaming/client.py index 3e6b7ae4..24cc431c 100644 --- a/src/litdata/streaming/client.py +++ b/src/litdata/streaming/client.py @@ -34,7 +34,7 @@ # def _create_client(self) -> None: # has_shared_credentials_file = ( -# os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials" +# os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials" # ) # if has_shared_credentials_file or not _IS_IN_STUDIO: @@ -46,7 +46,7 @@ # }, # ) # else: -# provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5)) +# provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5)) # credentials = provider.load() # self._client = boto3.client( # "s3", diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 
d96f1dda..fd09a36c 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -65,7 +65,7 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None: # obj = parse.urlparse(remote_filepath) # if obj.scheme != "s3": -# raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}") +# raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for {remote_filepath}") # if os.path.exists(local_filepath): # return @@ -115,7 +115,7 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None: # obj = parse.urlparse(remote_filepath) # if obj.scheme != "gs": -# raise ValueError(f"Expected obj.scheme to be `gs`, instead, got {obj.scheme} for remote={remote_filepath}") +# raise ValueError(f"Expected obj.scheme to be `gs`, instead, got {obj.scheme} for {remote_filepath}") # if os.path.exists(local_filepath): # return @@ -153,7 +153,7 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None: # if obj.scheme != "azure": # raise ValueError( -# f"Expected obj.scheme to be `azure`, instead, got {obj.scheme} for remote={remote_filepath}" +# f"Expected obj.scheme to be `azure`, instead, got {obj.scheme} for {remote_filepath}" # ) # if os.path.exists(local_filepath): From bf06cf9f10ca6bb44426b861bcdd3a63e4782b98 Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 4 Sep 2024 20:51:32 +0530 Subject: [PATCH 19/43] update --- src/litdata/streaming/downloader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index fd09a36c..b12a462a 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -19,7 +19,8 @@ import fsspec from filelock import FileLock, Timeout -# from litdata.constants import _AZURE_STORAGE_AVAILABLE, _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME +from litdata.constants import _AZURE_STORAGE_AVAILABLE, _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME + # from litdata.streaming.client import S3Client From bdc13f4932c45004476144a330fab812156599ca Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Sep 2024 15:21:45 +0000 Subject: [PATCH 20/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/downloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index b12a462a..9472e854 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -19,7 +19,7 @@ import fsspec from filelock import FileLock, Timeout -from litdata.constants import _AZURE_STORAGE_AVAILABLE, _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME +from litdata.constants import _INDEX_FILENAME # from litdata.streaming.client import S3Client From 909b5cbf20f70983f126a0f844773a6d28aea93e Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 4 Sep 2024 21:06:19 +0530 Subject: [PATCH 21/43] update --- tests/streaming/test_client.py | 144 ++++++++++++++++----------------- 1 file changed, 72 insertions(+), 72 deletions(-) diff --git a/tests/streaming/test_client.py b/tests/streaming/test_client.py index 78ea919d..4f06a9ce 100644 --- a/tests/streaming/test_client.py +++ b/tests/streaming/test_client.py @@ -1,97 +1,97 @@ -import sys -from time import sleep, time -from unittest import mock +# import sys +# from time 
import sleep, time +# from unittest import mock -import pytest -from litdata.streaming import client +# import pytest +# from litdata.streaming import client -def test_s3_client_with_storage_options(monkeypatch): - boto3 = mock.MagicMock() - monkeypatch.setattr(client, "boto3", boto3) +# def test_s3_client_with_storage_options(monkeypatch): +# boto3 = mock.MagicMock() +# monkeypatch.setattr(client, "boto3", boto3) - botocore = mock.MagicMock() - monkeypatch.setattr(client, "botocore", botocore) +# botocore = mock.MagicMock() +# monkeypatch.setattr(client, "botocore", botocore) - storage_options = { - "region_name": "us-west-2", - "endpoint_url": "https://custom.endpoint", - "config": botocore.config.Config(retries={"max_attempts": 100}), - } - s3_client = client.S3Client(storage_options=storage_options) +# storage_options = { +# "region_name": "us-west-2", +# "endpoint_url": "https://custom.endpoint", +# "config": botocore.config.Config(retries={"max_attempts": 100}), +# } +# s3_client = client.S3Client(storage_options=storage_options) - assert s3_client.client +# assert s3_client.client - boto3.client.assert_called_with( - "s3", - region_name="us-west-2", - endpoint_url="https://custom.endpoint", - config=botocore.config.Config(retries={"max_attempts": 100}), - ) +# boto3.client.assert_called_with( +# "s3", +# region_name="us-west-2", +# endpoint_url="https://custom.endpoint", +# config=botocore.config.Config(retries={"max_attempts": 100}), +# ) - s3_client = client.S3Client() +# s3_client = client.S3Client() - assert s3_client.client +# assert s3_client.client - boto3.client.assert_called_with( - "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}) - ) +# boto3.client.assert_called_with( +# "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}) +# ) -def test_s3_client_without_cloud_space_id(monkeypatch): - boto3 = mock.MagicMock() - monkeypatch.setattr(client, "boto3", boto3) +# def test_s3_client_without_cloud_space_id(monkeypatch): +# boto3 = mock.MagicMock() +# monkeypatch.setattr(client, "boto3", boto3) - botocore = mock.MagicMock() - monkeypatch.setattr(client, "botocore", botocore) +# botocore = mock.MagicMock() +# monkeypatch.setattr(client, "botocore", botocore) - instance_metadata_provider = mock.MagicMock() - monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider) +# instance_metadata_provider = mock.MagicMock() +# monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider) - instance_metadata_fetcher = mock.MagicMock() - monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher) +# instance_metadata_fetcher = mock.MagicMock() +# monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher) - s3 = client.S3Client(1) - assert s3.client - assert s3.client - assert s3.client - assert s3.client - assert s3.client +# s3 = client.S3Client(1) +# assert s3.client +# assert s3.client +# assert s3.client +# assert s3.client +# assert s3.client - boto3.client.assert_called_once() +# boto3.client.assert_called_once() -@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows") -@pytest.mark.parametrize("use_shared_credentials", [False, True, None]) -def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch): - boto3 = mock.MagicMock() - monkeypatch.setattr(client, "boto3", boto3) +# @pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows") +# 
@pytest.mark.parametrize("use_shared_credentials", [False, True, None]) +# def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch): +# boto3 = mock.MagicMock() +# monkeypatch.setattr(client, "boto3", boto3) - botocore = mock.MagicMock() - monkeypatch.setattr(client, "botocore", botocore) +# botocore = mock.MagicMock() +# monkeypatch.setattr(client, "botocore", botocore) - if isinstance(use_shared_credentials, bool): - monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy") - monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", "/.credentials/.aws_credentials") - monkeypatch.setenv("AWS_CONFIG_FILE", "/.credentials/.aws_credentials") +# if isinstance(use_shared_credentials, bool): +# monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy") +# monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", "/.credentials/.aws_credentials") +# monkeypatch.setenv("AWS_CONFIG_FILE", "/.credentials/.aws_credentials") - instance_metadata_provider = mock.MagicMock() - monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider) +# instance_metadata_provider = mock.MagicMock() +# monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider) - instance_metadata_fetcher = mock.MagicMock() - monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher) +# instance_metadata_fetcher = mock.MagicMock() +# monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher) - s3 = client.S3Client(1) - assert s3.client - assert s3.client - boto3.client.assert_called_once() - sleep(1 - (time() - s3._last_time)) - assert s3.client - assert s3.client - assert len(boto3.client._mock_mock_calls) == 6 - sleep(1 - (time() - s3._last_time)) - assert s3.client - assert s3.client - assert len(boto3.client._mock_mock_calls) == 9 +# s3 = client.S3Client(1) +# assert s3.client +# assert s3.client +# boto3.client.assert_called_once() +# sleep(1 - (time() - s3._last_time)) +# assert s3.client +# assert s3.client +# assert len(boto3.client._mock_mock_calls) == 6 +# sleep(1 - (time() - s3._last_time)) +# assert s3.client +# assert s3.client +# assert len(boto3.client._mock_mock_calls) == 9 - assert instance_metadata_provider._mock_call_count == 0 if use_shared_credentials else 3 +# assert instance_metadata_provider._mock_call_count == 0 if use_shared_credentials else 3 From 5671d11dbe10f937c77fe8f44e4ab651588ad9c8 Mon Sep 17 00:00:00 2001 From: deependu Date: Thu, 5 Sep 2024 13:09:17 +0530 Subject: [PATCH 22/43] tested on azure and made sure `storage_option` is working in all cases --- src/litdata/constants.py | 2 +- src/litdata/processing/data_processor.py | 44 +++++++++++++++++------ src/litdata/processing/functions.py | 45 +++++++++++++++++------- src/litdata/processing/utilities.py | 6 ++-- src/litdata/streaming/downloader.py | 3 +- src/litdata/streaming/resolver.py | 17 +++++---- 6 files changed, 83 insertions(+), 34 deletions(-) diff --git a/src/litdata/constants.py b/src/litdata/constants.py index 8f8f0f7a..a3258160 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -86,4 +86,4 @@ _TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ" _IS_IN_STUDIO = bool(os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)) and bool(os.getenv("LIGHTNING_CLUSTER_ID", None)) _ENABLE_STATUS = bool(int(os.getenv("ENABLE_STATUS_REPORT", "0"))) -_SUPPORTED_CLOUD_PROVIDERS = ["s3", "gs", "azure"] +_SUPPORTED_CLOUD_PROVIDERS = ["s3", "gs", "azure", "abfs"] diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 
449948f1..e2b3dcb8 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -101,7 +101,9 @@ def _get_cache_data_dir(name: Optional[str] = None) -> str: return os.path.join(cache_dir, name.lstrip("/")) -def _wait_for_file_to_exist(remote_filepath: str, sleep_time: int = 2, wait_for_count: int = 5) -> Any: +def _wait_for_file_to_exist( + remote_filepath: str, sleep_time: int = 2, wait_for_count: int = 5, storage_options: Optional[Dict] = {} +) -> Any: """This function check if a file exists on the remote storage. If not, it waits for a while and tries again. @@ -110,7 +112,7 @@ def _wait_for_file_to_exist(remote_filepath: str, sleep_time: int = 2, wait_for_ cloud_provider = get_cloud_provider(remote_filepath) while True: try: - return does_file_exist(remote_filepath, cloud_provider) + return does_file_exist(remote_filepath, cloud_provider, storage_options=storage_options) except Exception as e: if wait_for_count > 0: sleep(sleep_time) @@ -129,7 +131,9 @@ def _wait_for_disk_usage_higher_than_threshold(input_dir: str, threshold_in_gb: return -def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue) -> None: +def _download_data_target( + input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue, storage_options: Optional[Dict] = {} +) -> None: """This function is used to download data from a remote directory to a cache directory to optimise reading.""" while True: @@ -170,7 +174,7 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue dirpath = os.path.dirname(local_path) os.makedirs(dirpath, exist_ok=True) - download_file_or_directory(path, local_path) + download_file_or_directory(path, local_path, storage_options=storage_options) elif os.path.isfile(path): if not path.startswith("/teamspace/studios/this_studio"): @@ -206,7 +210,9 @@ def _remove_target(input_dir: Dir, cache_dir: str, queue_in: Queue) -> None: os.remove(path) -def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir) -> None: +def _upload_fn( + upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir, storage_options: Optional[Dict] = {} +) -> None: """This function is used to upload optimised chunks from a local to remote dataset directory.""" obj = parse.urlparse(output_dir.url if output_dir.url else output_dir.path) @@ -243,7 +249,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_ output_filepath = remove_uuid_from_filename(output_filepath) # remove unique id from checkpoints remote_filepath = str(obj.scheme) + "://" + str(obj.netloc) + "/" + output_filepath - upload_file_or_directory(local_filepath, remote_filepath) + upload_file_or_directory(local_filepath, remote_filepath, storage_options=storage_options) except Exception as e: print(e) @@ -420,6 +426,7 @@ def __init__( checkpoint_chunks_info: Optional[List[Dict[str, Any]]] = None, checkpoint_next_index: Optional[int] = None, item_loader: Optional[BaseItemLoader] = None, + storage_options: Optional[Dict] = {}, ) -> None: """The BaseWorker is responsible to process the user data.""" self.worker_index = worker_index @@ -454,6 +461,7 @@ def __init__( self.use_checkpoint: bool = use_checkpoint self.checkpoint_chunks_info: Optional[List[Dict[str, Any]]] = checkpoint_chunks_info self.checkpoint_next_index: Optional[int] = checkpoint_next_index + self.storage_options = storage_options def run(self) -> None: try: @@ -630,6 +638,7 @@ def _start_downloaders(self) -> None: 
self.cache_data_dir, to_download_queue, self.ready_to_process_queue, + self.storage_options, ), ) p.start() @@ -669,6 +678,7 @@ def _start_uploaders(self) -> None: self.remove_queue, self.cache_chunks_dir, self.output_dir, + self.storage_options, ), ) p.start() @@ -770,6 +780,7 @@ def __init__( chunk_bytes: Optional[Union[int, str]] = None, compression: Optional[str] = None, encryption: Optional[Encryption] = None, + storage_options: Optional[Dict] = {}, ): super().__init__() if chunk_size is not None and chunk_bytes is not None: @@ -779,6 +790,7 @@ def __init__( self.chunk_bytes = 1 << 26 if chunk_size is None and chunk_bytes is None else chunk_bytes self.compression = compression self.encryption = encryption + self.storage_options = storage_options @abstractmethod def prepare_structure(self, input_dir: Optional[str]) -> List[T]: @@ -850,6 +862,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra upload_file_or_directory( local_filepath, remote_filepath + os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath)), + storage_options=self.storage_options, ) elif output_dir.path and os.path.isdir(output_dir.path): shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath))) @@ -868,10 +881,11 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra remote_filepath = os.path.join(output_dir_path, f"{node_rank}-{_INDEX_FILENAME}") node_index_filepath = os.path.join(cache_dir, os.path.basename(remote_filepath)) if obj.scheme in _SUPPORTED_CLOUD_PROVIDERS: - _wait_for_file_to_exist(remote_filepath) + _wait_for_file_to_exist(remote_filepath, storage_options=self.storage_options) download_file_or_directory( remote_filepath, node_index_filepath, + storage_options=self.storage_options, ) elif output_dir.path and os.path.isdir(output_dir.path): shutil.copyfile(remote_filepath, node_index_filepath) @@ -913,6 +927,7 @@ def __init__( use_checkpoint: bool = False, item_loader: Optional[BaseItemLoader] = None, start_method: Optional[str] = None, + storage_options: Optional[Dict] = {}, ): """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make training faster. @@ -938,6 +953,7 @@ def __init__( the format in which the data is stored and optimized for loading. start_method: The start method used by python multiprocessing package. Default to spawn unless running inside an interactive shell like Ipython. + storage_options: The storage options used by the cloud provider. """ # spawn doesn't work in IPython @@ -974,6 +990,7 @@ def __init__( self.item_loader = item_loader self.state_dict = state_dict or {rank: 0 for rank in range(self.num_workers)} + self.storage_options = storage_options if self.reader is not None and self.weights is not None: raise ValueError("Either the reader or the weights needs to be defined.") @@ -1193,6 +1210,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L self.checkpoint_chunks_info[worker_idx] if self.checkpoint_chunks_info else None, self.checkpoint_next_index[worker_idx] if self.checkpoint_next_index else None, self.item_loader, + storage_options=self.storage_options, ) worker.start() workers.append(worker) @@ -1249,7 +1267,9 @@ def _cleanup_checkpoints(self) -> None: f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {self.output_dir.path}." 
) with suppress(FileNotFoundError): - remove_file_or_directory(os.path.join(self.output_dir.url, ".checkpoints")) + remove_file_or_directory( + os.path.join(self.output_dir.url, ".checkpoints"), storage_options=self.storage_options + ) def _save_current_config(self, workers_user_items: List[List[Any]]) -> None: if not self.use_checkpoint: @@ -1286,7 +1306,9 @@ def _save_current_config(self, workers_user_items: List[List[Any]]) -> None: with open(temp_file_name, "w") as f: json.dump(config, f) upload_file_or_directory( - temp_file_name, os.path.join(self.output_dir.url, ".checkpoints", "config.json") + temp_file_name, + os.path.join(self.output_dir.url, ".checkpoints", "config.json"), + storage_options=self.storage_options, ) except Exception as e: print(e) @@ -1345,7 +1367,9 @@ def _load_checkpoint_config(self, workers_user_items: List[List[Any]]) -> None: # download all the checkpoint files in tempdir and read them with tempfile.TemporaryDirectory() as temp_dir: try: - download_file_or_directory(os.path.join(self.output_dir.url, ".checkpoints/"), temp_dir) + download_file_or_directory( + os.path.join(self.output_dir.url, ".checkpoints/"), temp_dir, storage_options=self.storage_options + ) except FileNotFoundError: return if not os.path.exists(os.path.join(temp_dir, "config.json")): diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index 34948fc9..eb82cb4a 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -151,8 +151,12 @@ def __init__( compression: Optional[str], encryption: Optional[Encryption] = None, existing_index: Optional[Dict[str, Any]] = None, + storage_options: Optional[Dict] = {}, ): - super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression, encryption=encryption) + super().__init__( + chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression, encryption=encryption, + storage_options=storage_options, + ) self._fn = fn self._inputs = inputs self.is_generator = False @@ -199,6 +203,7 @@ def map( error_when_not_empty: bool = False, reader: Optional[BaseReader] = None, batch_size: Optional[int] = None, + storage_options: Optional[Dict] = {}, ) -> None: """This function maps a callable over a collection of inputs, possibly in a distributed way. @@ -219,6 +224,7 @@ def map( Set this to ``False`` if the order in which samples are processed should be preserved. error_when_not_empty: Whether we should error if the output folder isn't empty. batch_size: Group the inputs into batches of batch_size length. + storage_options: The storage options used by the cloud provider. """ if isinstance(inputs, StreamingDataLoader) and batch_size is not None: @@ -257,7 +263,7 @@ def map( ) if error_when_not_empty: - _assert_dir_is_empty(_output_dir) + _assert_dir_is_empty(_output_dir, storage_options=storage_options) if not isinstance(inputs, StreamingDataLoader): input_dir = input_dir or _get_input_dir(inputs) @@ -281,6 +287,7 @@ def map( reorder_files=reorder_files, weights=weights, reader=reader, + storage_options=storage_options, ) with optimize_dns_context(True): return data_processor.run(LambdaDataTransformRecipe(fn, inputs)) @@ -314,6 +321,7 @@ def optimize( use_checkpoint: bool = False, item_loader: Optional[BaseItemLoader] = None, start_method: Optional[str] = None, + storage_options: Optional[Dict] = {}, ) -> None: """This function converts a dataset into chunks, possibly in a distributed way. 
@@ -347,6 +355,7 @@ def optimize( the format in which the data is stored and optimized for loading. start_method: The start method used by python multiprocessing package. Default to spawn unless running inside an interactive shell like Ipython. + storage_options: The storage options used by the cloud provider. """ if mode is not None and mode not in ["append", "overwrite"]: @@ -401,7 +410,9 @@ def optimize( " HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`." ) - _assert_dir_has_index_file(_output_dir, mode=mode, use_checkpoint=use_checkpoint) + _assert_dir_has_index_file( + _output_dir, mode=mode, use_checkpoint=use_checkpoint, storage_options=storage_options + ) if not isinstance(inputs, StreamingDataLoader): resolved_dir = _resolve_dir(input_dir or _get_input_dir(inputs)) @@ -417,7 +428,9 @@ def optimize( num_workers = num_workers or _get_default_num_workers() state_dict = {rank: 0 for rank in range(num_workers)} - existing_index_file_content = read_index_file_content(_output_dir) if mode == "append" else None + existing_index_file_content = read_index_file_content( + _output_dir, storage_options=storage_options + ) if mode == "append" else None if existing_index_file_content is not None: for chunk in existing_index_file_content["chunks"]: @@ -439,6 +452,7 @@ def optimize( use_checkpoint=use_checkpoint, item_loader=item_loader, start_method=start_method, + storage_options=storage_options, ) with optimize_dns_context(True): @@ -451,6 +465,7 @@ def optimize( compression=compression, encryption=encryption, existing_index=existing_index_file_content, + storage_options=storage_options, ) ) return None @@ -519,7 +534,7 @@ class CopyInfo: new_filename: str -def merge_datasets(input_dirs: List[str], output_dir: str) -> None: +def merge_datasets(input_dirs: List[str], output_dir: str, storage_options: Optional[Dict] = {}) -> None: """The merge_datasets utility enables to merge multiple existing optimized datasets into a single optimized dataset. @@ -540,12 +555,14 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None: if any(input_dir == resolved_output_dir for input_dir in resolved_input_dirs): raise ValueError("The provided output_dir was found within the input_dirs. 
This isn't supported.") - input_dirs_file_content = [read_index_file_content(input_dir) for input_dir in resolved_input_dirs] + input_dirs_file_content = [ + read_index_file_content(input_dir, storage_options=storage_options) for input_dir in resolved_input_dirs + ] if any(file_content is None for file_content in input_dirs_file_content): raise ValueError("One of the provided input_dir doesn't have an index file.") - output_dir_file_content = read_index_file_content(resolved_output_dir) + output_dir_file_content = read_index_file_content(resolved_output_dir, storage_options=storage_options) if output_dir_file_content is not None: raise ValueError("The output_dir already contains an optimized dataset") @@ -580,12 +597,12 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None: _tqdm = _get_tqdm_iterator_if_available() for copy_info in _tqdm(copy_infos): - _apply_copy(copy_info, resolved_output_dir) + _apply_copy(copy_info, resolved_output_dir, storage_options=storage_options) - _save_index(index_json, resolved_output_dir) + _save_index(index_json, resolved_output_dir, storage_options=storage_options) -def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None: +def _apply_copy(copy_info: CopyInfo, output_dir: Dir, storage_options: Optional[Dict] = {}) -> None: if output_dir.url is None and copy_info.input_dir.url is None: assert copy_info.input_dir.path assert output_dir.path @@ -598,12 +615,12 @@ def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None: input_obj = os.path.join(copy_info.input_dir.url, copy_info.old_filename) output_obj = os.path.join(output_dir.url, copy_info.new_filename) - copy_file_or_directory(input_obj, output_obj) + copy_file_or_directory(input_obj, output_obj, storage_options=storage_options) else: raise NotImplementedError -def _save_index(index_json: Dict, output_dir: Dir) -> None: +def _save_index(index_json: Dict, output_dir: Dir, storage_options: Optional[Dict] = {}) -> None: if output_dir.url is None: assert output_dir.path with open(os.path.join(output_dir.path, _INDEX_FILENAME), "w") as f: @@ -614,4 +631,6 @@ def _save_index(index_json: Dict, output_dir: Dir) -> None: f.flush() - upload_file_or_directory(f.name, os.path.join(output_dir.url, _INDEX_FILENAME)) + upload_file_or_directory( + f.name, os.path.join(output_dir.url, _INDEX_FILENAME), storage_options=storage_options + ) diff --git a/src/litdata/processing/utilities.py b/src/litdata/processing/utilities.py index 4c61f48d..800195ee 100644 --- a/src/litdata/processing/utilities.py +++ b/src/litdata/processing/utilities.py @@ -181,7 +181,7 @@ def _get_work_dir() -> str: return f"s3://{bucket_name}/projects/{project_id}/lightningapps/{app_id}/artifacts/{work_id}/content/" -def read_index_file_content(output_dir: Dir) -> Optional[Dict[str, Any]]: +def read_index_file_content(output_dir: Dir, storage_options: Optional[Dict] = {}) -> Optional[Dict[str, Any]]: """Read the index file content.""" if not isinstance(output_dir, Dir): raise ValueError("The provided output_dir should be a Dir object.") @@ -209,7 +209,9 @@ def read_index_file_content(output_dir: Dir) -> Optional[Dict[str, Any]]: # Create a temporary file with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file: temp_file_name = temp_file.name - download_file_or_directory(os.path.join(output_dir.url, _INDEX_FILENAME), temp_file_name) + download_file_or_directory( + os.path.join(output_dir.url, _INDEX_FILENAME), temp_file_name, storage_options=storage_options + ) # Read data from the temporary file with 
open(temp_file_name) as temp_file: data = json.load(temp_file) diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 6d73cf28..a706dfcc 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -211,6 +211,7 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None: "s3://": "s3", "gs://": "gs", "azure://": "abfs", + "abfs://": "abfs", "local:": "file", "": "file", } @@ -256,7 +257,7 @@ def does_file_exist( def list_directory( remote_directory: str, detail: bool = False, - cloud_provider: Union[str, None] = None, + cloud_provider: Optional[str] = None, storage_options: Optional[Dict] = {}, ) -> List[str]: """Returns a list of filenames in a remote directory.""" diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index a351eb25..2b8de6fb 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -20,7 +20,7 @@ from dataclasses import dataclass from pathlib import Path from time import sleep -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union from urllib import parse from litdata.constants import _LIGHTNING_SDK_AVAILABLE, _SUPPORTED_CLOUD_PROVIDERS @@ -208,7 +208,9 @@ def _resolve_datasets(dir_path: str) -> Dir: ) -def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool = False) -> None: +def _assert_dir_is_empty( + output_dir: Dir, append: bool = False, overwrite: bool = False, storage_options: Optional[Dict] = {} +) -> None: if not isinstance(output_dir, Dir): raise ValueError("The provided output_dir isn't a Dir Object.") @@ -221,7 +223,7 @@ def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool raise ValueError(f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {output_dir.url}.") try: - object_list = list_directory(output_dir.url) + object_list = list_directory(output_dir.url, storage_options=storage_options) except FileNotFoundError: return print(f"{object_list=}") @@ -234,7 +236,8 @@ def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool def _assert_dir_has_index_file( - output_dir: Dir, mode: Optional[Literal["append", "overwrite"]] = None, use_checkpoint: bool = False + output_dir: Dir, mode: Optional[Literal["append", "overwrite"]] = None, use_checkpoint: bool = False, + storage_options: Optional[Dict] = {} ) -> None: if mode is not None and mode not in ["append", "overwrite"]: raise ValueError(f"The provided `mode` should be either `append` or `overwrite`. 
Found {mode}.") @@ -283,14 +286,14 @@ def _assert_dir_has_index_file( objects_list = [] with suppress(FileNotFoundError): - objects_list = list_directory(output_dir.url) + objects_list = list_directory(output_dir.url, storage_options=storage_options) # No files are found in this folder if objects_list is None or len(objects_list) == 0: return # Check the index file exists - has_index_file = does_file_exist(os.path.join(output_dir.url, "index.json")) + has_index_file = does_file_exist(os.path.join(output_dir.url, "index.json"), storage_options=storage_options) if has_index_file and mode is None: raise RuntimeError( @@ -300,7 +303,7 @@ def _assert_dir_has_index_file( ) if mode == "overwrite" or (mode is None and not use_checkpoint): - remove_file_or_directory(output_dir.url) + remove_file_or_directory(output_dir.url, storage_options=storage_options) def _get_lightning_cloud_url() -> str: From 2beebc9e73b97fbefca3705e90f6e79d86c9073c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 07:39:57 +0000 Subject: [PATCH 23/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/processing/functions.py | 11 +++++++---- src/litdata/streaming/resolver.py | 6 ++++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index eb82cb4a..bd80e0e5 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -154,7 +154,10 @@ def __init__( storage_options: Optional[Dict] = {}, ): super().__init__( - chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression, encryption=encryption, + chunk_size=chunk_size, + chunk_bytes=chunk_bytes, + compression=compression, + encryption=encryption, storage_options=storage_options, ) self._fn = fn @@ -428,9 +431,9 @@ def optimize( num_workers = num_workers or _get_default_num_workers() state_dict = {rank: 0 for rank in range(num_workers)} - existing_index_file_content = read_index_file_content( - _output_dir, storage_options=storage_options - ) if mode == "append" else None + existing_index_file_content = ( + read_index_file_content(_output_dir, storage_options=storage_options) if mode == "append" else None + ) if existing_index_file_content is not None: for chunk in existing_index_file_content["chunks"]: diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index 2b8de6fb..33e40272 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -236,8 +236,10 @@ def _assert_dir_is_empty( def _assert_dir_has_index_file( - output_dir: Dir, mode: Optional[Literal["append", "overwrite"]] = None, use_checkpoint: bool = False, - storage_options: Optional[Dict] = {} + output_dir: Dir, + mode: Optional[Literal["append", "overwrite"]] = None, + use_checkpoint: bool = False, + storage_options: Optional[Dict] = {}, ) -> None: if mode is not None and mode not in ["append", "overwrite"]: raise ValueError(f"The provided `mode` should be either `append` or `overwrite`. 
Found {mode}.") From dd2e7427c58f5efc6ccfa02209a32984c0e5586d Mon Sep 17 00:00:00 2001 From: deependu Date: Thu, 5 Sep 2024 13:24:43 +0530 Subject: [PATCH 24/43] update --- requirements.txt | 1 + requirements/extras.txt | 1 - tests/processing/test_data_processor.py | 3 ++- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 6adea00f..ec443722 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ numpy # boto3 requests fsspec +fsspec[s3] # aws s3 diff --git a/requirements/extras.txt b/requirements/extras.txt index 69e5fafe..32bc2b54 100644 --- a/requirements/extras.txt +++ b/requirements/extras.txt @@ -5,6 +5,5 @@ pyarrow tqdm lightning-cloud == 0.5.70 # Must be pinned to ensure compatibility google-cloud-storage -fsspec[s3] # aws s3 fsspec[gs] # google cloud storage fsspec[abfs] # azure blob diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py index 8699f5e0..20a6d280 100644 --- a/tests/processing/test_data_processor.py +++ b/tests/processing/test_data_processor.py @@ -10,6 +10,7 @@ import pytest import torch from lightning_utilities.core.imports import RequirementCache + from litdata.constants import _TORCH_AUDIO_AVAILABLE, _ZSTD_AVAILABLE from litdata.processing import data_processor as data_processor_module from litdata.processing import functions @@ -111,7 +112,7 @@ def fn(*_, **__): called = False - def copy_file(local_filepath, *args): + def copy_file(local_filepath, *args, **kwargs): nonlocal called called = True from shutil import copyfile From f5550693d178115343ebf216d5cc53b2648c1a98 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 07:55:11 +0000 Subject: [PATCH 25/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/processing/test_data_processor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py index 20a6d280..493d9b48 100644 --- a/tests/processing/test_data_processor.py +++ b/tests/processing/test_data_processor.py @@ -10,7 +10,6 @@ import pytest import torch from lightning_utilities.core.imports import RequirementCache - from litdata.constants import _TORCH_AUDIO_AVAILABLE, _ZSTD_AVAILABLE from litdata.processing import data_processor as data_processor_module from litdata.processing import functions From 87a95560467fe8aaab8d6618a7c7446e3ca1d62b Mon Sep 17 00:00:00 2001 From: deependu Date: Thu, 5 Sep 2024 14:02:57 +0530 Subject: [PATCH 26/43] update --- src/litdata/streaming/downloader.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index a706dfcc..b1aeb8e2 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -34,9 +34,6 @@ def __init__( chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}, ): - print("-" * 80) - print(f"{cloud_provider=}") - print("-" * 80) self._remote_dir = remote_dir self._cache_dir = cache_dir From 8e9d44866c50f44237a29e98f39f9ed7217ddacd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 08:33:15 +0000 Subject: [PATCH 27/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/downloader.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index b1aeb8e2..fb79cfd3 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -34,7 +34,6 @@ def __init__( chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}, ): - self._remote_dir = remote_dir self._cache_dir = cache_dir self._chunks = chunks From 555eb19aaecd1bba2ec6460bf2cba6ae6a1aa9b2 Mon Sep 17 00:00:00 2001 From: deependu Date: Thu, 5 Sep 2024 15:19:39 +0530 Subject: [PATCH 28/43] use s5cmd to download files if available --- src/litdata/streaming/downloader.py | 42 ++++++++++++++++++++++++++++- src/litdata/streaming/resolver.py | 2 +- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index fb79cfd3..80005d6c 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -14,8 +14,10 @@ import contextlib import os import shutil +import subprocess from abc import ABC from typing import Any, Dict, List, Optional, Union +from urllib import parse import fsspec from filelock import FileLock, Timeout @@ -24,6 +26,7 @@ # from litdata.streaming.client import S3Client +_USE_S5CMD_FOR_S3 = True class Downloader(ABC): def __init__( @@ -202,6 +205,34 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None: # "": LocalDownloader, # } +def download_s3_file_via_s5cmd(remote_filepath: str, local_filepath: str) -> None: + + _s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 + + if _s5cmd_available is False: + raise ModuleNotFoundError(str(_s5cmd_available)) + + obj = parse.urlparse(remote_filepath) + + + if obj.scheme != "s3": + raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for {remote_filepath}") + + if os.path.exists(local_filepath): + return + + try: + with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): + proc = subprocess.Popen( + f"s5cmd cp {remote_filepath} {local_filepath}", + shell=True, + stdout=subprocess.PIPE, + ) + proc.wait() + except Timeout: + # another process is responsible to download that file, continue + pass + _DOWNLOADERS = { "s3://": "s3", @@ -225,10 +256,15 @@ def __init__( remote_dir = remote_dir.replace("local:", "") self.is_local = False super().__init__(cloud_provider, remote_dir, cache_dir, chunks, storage_options) + self.cloud_provider = cloud_provider + self.use_s5cmd = cloud_provider=='s3' and os.system("s5cmd > /dev/null 2>&1") == 0 def download_file(self, remote_filepath: str, local_filepath: str) -> None: if os.path.exists(local_filepath) or remote_filepath == local_filepath: return + if self.use_s5cmd and _USE_S5CMD_FOR_S3: + download_s3_file_via_s5cmd(remote_filepath, local_filepath) + return try: with FileLock(local_filepath + ".lock", timeout=3): self.fs.get(remote_filepath, local_filepath, recursive=True) @@ -266,9 +302,13 @@ def list_directory( def download_file_or_directory(remote_filepath: str, local_filepath: str, storage_options: Optional[Dict] = {}) -> None: """Download a file from the remote cloud storage.""" + fs_cloud_provider = get_cloud_provider(remote_filepath) + use_s5cmd = fs_cloud_provider == "s3" and os.system("s5cmd > /dev/null 2>&1") == 0 + if use_s5cmd and _USE_S5CMD_FOR_S3: + download_s3_file_via_s5cmd(remote_filepath, local_filepath) + return try: with FileLock(local_filepath + ".lock", timeout=3): - fs_cloud_provider = get_cloud_provider(remote_filepath) fs = 
fsspec.filesystem(fs_cloud_provider, **storage_options) fs.get(remote_filepath, local_filepath, recursive=True) except Timeout: diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index 33e40272..b4035851 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -226,7 +226,7 @@ def _assert_dir_is_empty( object_list = list_directory(output_dir.url, storage_options=storage_options) except FileNotFoundError: return - print(f"{object_list=}") + # We aren't alloweing to add more data if object_list is not None and len(object_list) > 0: raise RuntimeError( From 5ef4004b8ae0187a4499ed566611d74cde291803 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 09:51:18 +0000 Subject: [PATCH 29/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/downloader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 80005d6c..8a4591df 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -28,6 +28,7 @@ _USE_S5CMD_FOR_S3 = True + class Downloader(ABC): def __init__( self, @@ -205,8 +206,8 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None: # "": LocalDownloader, # } -def download_s3_file_via_s5cmd(remote_filepath: str, local_filepath: str) -> None: +def download_s3_file_via_s5cmd(remote_filepath: str, local_filepath: str) -> None: _s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 if _s5cmd_available is False: @@ -214,7 +215,6 @@ def download_s3_file_via_s5cmd(remote_filepath: str, local_filepath: str) -> Non obj = parse.urlparse(remote_filepath) - if obj.scheme != "s3": raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for {remote_filepath}") @@ -257,7 +257,7 @@ def __init__( self.is_local = False super().__init__(cloud_provider, remote_dir, cache_dir, chunks, storage_options) self.cloud_provider = cloud_provider - self.use_s5cmd = cloud_provider=='s3' and os.system("s5cmd > /dev/null 2>&1") == 0 + self.use_s5cmd = cloud_provider == "s3" and os.system("s5cmd > /dev/null 2>&1") == 0 def download_file(self, remote_filepath: str, local_filepath: str) -> None: if os.path.exists(local_filepath) or remote_filepath == local_filepath: From 67205eab02fa5e3192ebe35f1105658fff433f5b Mon Sep 17 00:00:00 2001 From: deependu Date: Fri, 6 Sep 2024 13:29:10 +0530 Subject: [PATCH 30/43] add default storage_options --- src/litdata/streaming/downloader.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 8a4591df..046361ba 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -243,6 +243,17 @@ def download_s3_file_via_s5cmd(remote_filepath: str, local_filepath: str) -> Non "": "file", } +_DEFAULT_STORAGE_OPTIONS = { + "s3":{"config_kwargs": {"retries":{"max_attempts": 1000, "mode": "adaptive"}}}, +} + +def get_complete_storage_options(cloud_provider: str, storage_options: Optional[Dict] = {}) -> Dict: + if storage_options is None: + storage_options = {} + if cloud_provider in _DEFAULT_STORAGE_OPTIONS: + return {**_DEFAULT_STORAGE_OPTIONS[cloud_provider], **storage_options} + return storage_options + class FsspecDownloader(Downloader): def __init__( @@ -255,6 
+266,7 @@ def __init__( ): remote_dir = remote_dir.replace("local:", "") self.is_local = False + storage_options = get_complete_storage_options(cloud_provider, storage_options) super().__init__(cloud_provider, remote_dir, cache_dir, chunks, storage_options) self.cloud_provider = cloud_provider self.use_s5cmd = cloud_provider == "s3" and os.system("s5cmd > /dev/null 2>&1") == 0 @@ -281,7 +293,7 @@ def does_file_exist( ) -> bool: if cloud_provider is None: cloud_provider = get_cloud_provider(remote_filepath) - + storage_options = get_complete_storage_options(cloud_provider, storage_options) fs = fsspec.filesystem(cloud_provider, **storage_options) return fs.exists(remote_filepath) @@ -295,7 +307,7 @@ def list_directory( """Returns a list of filenames in a remote directory.""" if cloud_provider is None: cloud_provider = get_cloud_provider(remote_directory) - + storage_options = get_complete_storage_options(cloud_provider, storage_options) fs = fsspec.filesystem(cloud_provider, **storage_options) return fs.ls(remote_directory, detail=detail) # just return the filenames @@ -309,6 +321,7 @@ def download_file_or_directory(remote_filepath: str, local_filepath: str, storag return try: with FileLock(local_filepath + ".lock", timeout=3): + storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) fs = fsspec.filesystem(fs_cloud_provider, **storage_options) fs.get(remote_filepath, local_filepath, recursive=True) except Timeout: @@ -321,6 +334,7 @@ def upload_file_or_directory(local_filepath: str, remote_filepath: str, storage_ try: with FileLock(local_filepath + ".lock", timeout=3): fs_cloud_provider = get_cloud_provider(remote_filepath) + storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) fs = fsspec.filesystem(fs_cloud_provider, **storage_options) fs.put(local_filepath, remote_filepath, recursive=True) except Timeout: @@ -333,6 +347,7 @@ def copy_file_or_directory( ) -> None: """Copy a file from src to target on the remote cloud storage.""" fs_cloud_provider = get_cloud_provider(remote_filepath_src) + storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) fs = fsspec.filesystem(fs_cloud_provider, **storage_options) fs.copy(remote_filepath_src, remote_filepath_tg, recursive=True) @@ -340,6 +355,7 @@ def copy_file_or_directory( def remove_file_or_directory(remote_filepath: str, storage_options: Optional[Dict] = {}) -> None: """Remove a file from the remote cloud storage.""" fs_cloud_provider = get_cloud_provider(remote_filepath) + storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) fs = fsspec.filesystem(fs_cloud_provider, **storage_options) fs.rm(remote_filepath, recursive=True) From 69fb43d3702aa100caca27c7286dcaa96f7d1d23 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Sep 2024 07:59:50 +0000 Subject: [PATCH 31/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/downloader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 046361ba..093c3451 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -244,9 +244,10 @@ def download_s3_file_via_s5cmd(remote_filepath: str, local_filepath: str) -> Non } _DEFAULT_STORAGE_OPTIONS = { - "s3":{"config_kwargs": {"retries":{"max_attempts": 1000, "mode": 
"adaptive"}}}, + "s3": {"config_kwargs": {"retries": {"max_attempts": 1000, "mode": "adaptive"}}}, } + def get_complete_storage_options(cloud_provider: str, storage_options: Optional[Dict] = {}) -> Dict: if storage_options is None: storage_options = {} From b49a1265b47ac51d3b635a748f33341c65b859b9 Mon Sep 17 00:00:00 2001 From: deependu Date: Fri, 6 Sep 2024 14:32:50 +0530 Subject: [PATCH 32/43] raise error if cloud is not supported --- src/litdata/streaming/resolver.py | 10 ++++++++-- tests/streaming/test_resolver.py | 8 ++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index b4035851..bffea8f2 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -55,8 +55,14 @@ def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir: assert isinstance(dir_path, str) cloud_prefixes = _SUPPORTED_CLOUD_PROVIDERS - if any(dir_path.startswith(cloud_prefix) for cloud_prefix in cloud_prefixes): - return Dir(path=None, url=dir_path) + dir_scheme = parse.urlparse(dir_path).scheme + if bool(dir_scheme): + if any(dir_path.startswith(cloud_prefix) for cloud_prefix in cloud_prefixes): + return Dir(path=None, url=dir_path) + raise ValueError( + f"The provided dir_path `{dir_path}` is not supported.", + f" HINT: Only the following cloud providers are supported: {_SUPPORTED_CLOUD_PROVIDERS}.", + ) if dir_path.startswith("local:"): return Dir(path=None, url=dir_path) diff --git a/tests/streaming/test_resolver.py b/tests/streaming/test_resolver.py index 630411f3..7e71de88 100644 --- a/tests/streaming/test_resolver.py +++ b/tests/streaming/test_resolver.py @@ -15,6 +15,7 @@ V1ListClustersResponse, V1ListDataConnectionsResponse, ) + from litdata.streaming import resolver @@ -367,3 +368,10 @@ def test_resolve_dir_absolute(tmp_path, monkeypatch): link.symlink_to(src) assert link.resolve() == src assert resolver._resolve_dir(str(link)).path == str(src) + +def test_resolve_dir_unsupported_cloud_provider(monkeypatch, tmp_path): + """Test that the unsupported cloud provider is handled correctly.""" + + test_dir = "some-random-cloud-provider://some-random-bucket" + with pytest.raises(ValueError, match="The provided dir_path"): + resolver._resolve_dir(test_dir) From dbe8b0eb397f8cd32010b566e1fb7c0382048a72 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Sep 2024 09:03:20 +0000 Subject: [PATCH 33/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/streaming/test_resolver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/streaming/test_resolver.py b/tests/streaming/test_resolver.py index 7e71de88..0004c17e 100644 --- a/tests/streaming/test_resolver.py +++ b/tests/streaming/test_resolver.py @@ -15,7 +15,6 @@ V1ListClustersResponse, V1ListDataConnectionsResponse, ) - from litdata.streaming import resolver @@ -369,6 +368,7 @@ def test_resolve_dir_absolute(tmp_path, monkeypatch): assert link.resolve() == src assert resolver._resolve_dir(str(link)).path == str(src) + def test_resolve_dir_unsupported_cloud_provider(monkeypatch, tmp_path): """Test that the unsupported cloud provider is handled correctly.""" From 5a81f04e719010c39e36f3d396d7cc64d9e44edc Mon Sep 17 00:00:00 2001 From: deependu Date: Fri, 6 Sep 2024 14:51:21 +0530 Subject: [PATCH 34/43] update --- src/litdata/streaming/resolver.py | 3 +++ 1 file changed, 3 insertions(+) diff --git 
a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index bffea8f2..0568e736 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -57,6 +57,9 @@ def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir: cloud_prefixes = _SUPPORTED_CLOUD_PROVIDERS dir_scheme = parse.urlparse(dir_path).scheme if bool(dir_scheme): + print("="*80) + print(f"{dir_scheme=}") + print("="*80) if any(dir_path.startswith(cloud_prefix) for cloud_prefix in cloud_prefixes): return Dir(path=None, url=dir_path) raise ValueError( From 848484a67735ef1b22c4957f773c6502b58e1847 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Sep 2024 09:21:37 +0000 Subject: [PATCH 35/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/resolver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index 0568e736..fead445c 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -57,9 +57,9 @@ def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir: cloud_prefixes = _SUPPORTED_CLOUD_PROVIDERS dir_scheme = parse.urlparse(dir_path).scheme if bool(dir_scheme): - print("="*80) + print("=" * 80) print(f"{dir_scheme=}") - print("="*80) + print("=" * 80) if any(dir_path.startswith(cloud_prefix) for cloud_prefix in cloud_prefixes): return Dir(path=None, url=dir_path) raise ValueError( From 4d62fddefa702951491b23014b565b91c8aa501d Mon Sep 17 00:00:00 2001 From: deependu Date: Fri, 6 Sep 2024 15:05:23 +0530 Subject: [PATCH 36/43] fix windows error related to urllib parse scheme --- src/litdata/streaming/resolver.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index fead445c..3d292755 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -56,10 +56,7 @@ def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir: cloud_prefixes = _SUPPORTED_CLOUD_PROVIDERS dir_scheme = parse.urlparse(dir_path).scheme - if bool(dir_scheme): - print("=" * 80) - print(f"{dir_scheme=}") - print("=" * 80) + if bool(dir_scheme) and dir_scheme not in ["c", "d", "e" , "f"]: # prevent windows `c:\\` and `d:\\` if any(dir_path.startswith(cloud_prefix) for cloud_prefix in cloud_prefixes): return Dir(path=None, url=dir_path) raise ValueError( From e544d0982dc02296e52c857dbfa1fcff6d65cd39 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Sep 2024 09:37:10 +0000 Subject: [PATCH 37/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/resolver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index 3d292755..3c2bc4a3 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -56,7 +56,7 @@ def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir: cloud_prefixes = _SUPPORTED_CLOUD_PROVIDERS dir_scheme = parse.urlparse(dir_path).scheme - if bool(dir_scheme) and dir_scheme not in ["c", "d", "e" , "f"]: # prevent windows `c:\\` and `d:\\` + if bool(dir_scheme) and dir_scheme not in ["c", "d", "e", "f"]: # prevent windows `c:\\` and `d:\\` if 
any(dir_path.startswith(cloud_prefix) for cloud_prefix in cloud_prefixes): return Dir(path=None, url=dir_path) raise ValueError( From e68076d8d05f8b20cc3828598751ccf6b24d7885 Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 18 Sep 2024 10:03:13 +0530 Subject: [PATCH 38/43] cleanup commented code --- src/litdata/streaming/client.py | 70 --------------- src/litdata/streaming/downloader.py | 130 ---------------------------- tests/streaming/test_client.py | 97 --------------------- tests/streaming/test_downloader.py | 8 -- 4 files changed, 305 deletions(-) delete mode 100644 src/litdata/streaming/client.py delete mode 100644 tests/streaming/test_client.py diff --git a/src/litdata/streaming/client.py b/src/litdata/streaming/client.py deleted file mode 100644 index 24cc431c..00000000 --- a/src/litdata/streaming/client.py +++ /dev/null @@ -1,70 +0,0 @@ -# # Copyright The Lightning AI team. -# # Licensed under the Apache License, Version 2.0 (the "License"); -# # you may not use this file except in compliance with the License. -# # You may obtain a copy of the License at -# # -# # http://www.apache.org/licenses/LICENSE-2.0 -# # -# # Unless required by applicable law or agreed to in writing, software -# # distributed under the License is distributed on an "AS IS" BASIS, -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# # See the License for the specific language governing permissions and -# # limitations under the License. - -# import os -# from time import time -# from typing import Any, Dict, Optional - -# import boto3 -# import botocore -# from botocore.credentials import InstanceMetadataProvider -# from botocore.utils import InstanceMetadataFetcher - -# from litdata.constants import _IS_IN_STUDIO - - -# class S3Client: -# # TODO: Generalize to support more cloud providers. 
- -# def __init__(self, refetch_interval: int = 3300, storage_options: Optional[Dict] = {}) -> None: -# self._refetch_interval = refetch_interval -# self._last_time: Optional[float] = None -# self._client: Optional[Any] = None -# self._storage_options: dict = storage_options or {} - -# def _create_client(self) -> None: -# has_shared_credentials_file = ( -# os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials" -# ) - -# if has_shared_credentials_file or not _IS_IN_STUDIO: -# self._client = boto3.client( -# "s3", -# **{ -# "config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}), -# **self._storage_options, -# }, -# ) -# else: -# provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5)) -# credentials = provider.load() -# self._client = boto3.client( -# "s3", -# aws_access_key_id=credentials.access_key, -# aws_secret_access_key=credentials.secret_key, -# aws_session_token=credentials.token, -# config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}), -# ) - -# @property -# def client(self) -> Any: -# if self._client is None: -# self._create_client() -# self._last_time = time() - -# # Re-generate credentials for EC2 -# if self._last_time is None or (time() - self._last_time) > self._refetch_interval: -# self._create_client() -# self._last_time = time() - -# return self._client diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 093c3451..463ab576 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -53,127 +53,6 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None: pass -# class S3Downloader(Downloader): -# def __init__( -# self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} -# ): -# super().__init__("s3", remote_dir, cache_dir, chunks, storage_options) -# self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 - -# if not self._s5cmd_available: -# self._client = S3Client(storage_options=self._storage_options) - -# def download_file(self, remote_filepath: str, local_filepath: str) -> None: -# obj = parse.urlparse(remote_filepath) - -# if obj.scheme != "s3": -# raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for {remote_filepath}") - -# if os.path.exists(local_filepath): -# return - -# try: -# with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): -# if self._s5cmd_available: -# proc = subprocess.Popen( -# f"s5cmd cp {remote_filepath} {local_filepath}", -# shell=True, -# stdout=subprocess.PIPE, -# ) -# proc.wait() -# else: -# from boto3.s3.transfer import TransferConfig - -# extra_args: Dict[str, Any] = {} - -# # try: -# # with FileLock(local_filepath + ".lock", timeout=1): -# if not os.path.exists(local_filepath): -# # Issue: https://github.com/boto/boto3/issues/3113 -# self._client.client.download_file( -# obj.netloc, -# obj.path.lstrip("/"), -# local_filepath, -# ExtraArgs=extra_args, -# Config=TransferConfig(use_threads=False), -# ) -# except Timeout: -# # another process is responsible to download that file, continue -# pass - - -# class GCPDownloader(Downloader): -# def __init__( -# self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} -# ): -# if not _GOOGLE_STORAGE_AVAILABLE: -# raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE)) - -# 
super().__init__("gs", remote_dir, cache_dir, chunks, storage_options) - -# def download_file(self, remote_filepath: str, local_filepath: str) -> None: -# from google.cloud import storage - -# obj = parse.urlparse(remote_filepath) - -# if obj.scheme != "gs": -# raise ValueError(f"Expected obj.scheme to be `gs`, instead, got {obj.scheme} for {remote_filepath}") - -# if os.path.exists(local_filepath): -# return - -# try: -# with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): -# bucket_name = obj.netloc -# key = obj.path -# # Remove the leading "/": -# if key[0] == "/": -# key = key[1:] - -# client = storage.Client(**self._storage_options) -# bucket = client.bucket(bucket_name) -# blob = bucket.blob(key) -# blob.download_to_filename(local_filepath) -# except Timeout: -# # another process is responsible to download that file, continue -# pass - - -# class AzureDownloader(Downloader): -# def __init__( -# self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} -# ): -# if not _AZURE_STORAGE_AVAILABLE: -# raise ModuleNotFoundError(str(_AZURE_STORAGE_AVAILABLE)) - -# super().__init__("abfs", remote_dir, cache_dir, chunks, storage_options) - -# def download_file(self, remote_filepath: str, local_filepath: str) -> None: -# from azure.storage.blob import BlobServiceClient - -# obj = parse.urlparse(remote_filepath) - -# if obj.scheme != "azure": -# raise ValueError( -# f"Expected obj.scheme to be `azure`, instead, got {obj.scheme} for {remote_filepath}" -# ) - -# if os.path.exists(local_filepath): -# return - -# try: -# with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): -# service = BlobServiceClient(**self._storage_options) -# blob_client = service.get_blob_client(container=obj.netloc, blob=obj.path.lstrip("/")) -# with open(local_filepath, "wb") as download_file: -# blob_data = blob_client.download_blob() -# blob_data.readinto(download_file) - -# except Timeout: -# # another process is responsible to download that file, continue -# pass - - class LocalDownloader(Downloader): def download_file(self, remote_filepath: str, local_filepath: str) -> None: if not os.path.exists(remote_filepath): @@ -198,15 +77,6 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None: super().download_file(remote_filepath, local_filepath) -# _DOWNLOADERS = { -# "s3://": S3Downloader, -# "gs://": GCPDownloader, -# "azure://": AzureDownloader, -# "local:": LocalDownloaderWithCache, -# "": LocalDownloader, -# } - - def download_s3_file_via_s5cmd(remote_filepath: str, local_filepath: str) -> None: _s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 diff --git a/tests/streaming/test_client.py b/tests/streaming/test_client.py deleted file mode 100644 index 4f06a9ce..00000000 --- a/tests/streaming/test_client.py +++ /dev/null @@ -1,97 +0,0 @@ -# import sys -# from time import sleep, time -# from unittest import mock - -# import pytest -# from litdata.streaming import client - - -# def test_s3_client_with_storage_options(monkeypatch): -# boto3 = mock.MagicMock() -# monkeypatch.setattr(client, "boto3", boto3) - -# botocore = mock.MagicMock() -# monkeypatch.setattr(client, "botocore", botocore) - -# storage_options = { -# "region_name": "us-west-2", -# "endpoint_url": "https://custom.endpoint", -# "config": botocore.config.Config(retries={"max_attempts": 100}), -# } -# s3_client = client.S3Client(storage_options=storage_options) - -# assert s3_client.client - -# 
boto3.client.assert_called_with( -# "s3", -# region_name="us-west-2", -# endpoint_url="https://custom.endpoint", -# config=botocore.config.Config(retries={"max_attempts": 100}), -# ) - -# s3_client = client.S3Client() - -# assert s3_client.client - -# boto3.client.assert_called_with( -# "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}) -# ) - - -# def test_s3_client_without_cloud_space_id(monkeypatch): -# boto3 = mock.MagicMock() -# monkeypatch.setattr(client, "boto3", boto3) - -# botocore = mock.MagicMock() -# monkeypatch.setattr(client, "botocore", botocore) - -# instance_metadata_provider = mock.MagicMock() -# monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider) - -# instance_metadata_fetcher = mock.MagicMock() -# monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher) - -# s3 = client.S3Client(1) -# assert s3.client -# assert s3.client -# assert s3.client -# assert s3.client -# assert s3.client - -# boto3.client.assert_called_once() - - -# @pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows") -# @pytest.mark.parametrize("use_shared_credentials", [False, True, None]) -# def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch): -# boto3 = mock.MagicMock() -# monkeypatch.setattr(client, "boto3", boto3) - -# botocore = mock.MagicMock() -# monkeypatch.setattr(client, "botocore", botocore) - -# if isinstance(use_shared_credentials, bool): -# monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy") -# monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", "/.credentials/.aws_credentials") -# monkeypatch.setenv("AWS_CONFIG_FILE", "/.credentials/.aws_credentials") - -# instance_metadata_provider = mock.MagicMock() -# monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider) - -# instance_metadata_fetcher = mock.MagicMock() -# monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher) - -# s3 = client.S3Client(1) -# assert s3.client -# assert s3.client -# boto3.client.assert_called_once() -# sleep(1 - (time() - s3._last_time)) -# assert s3.client -# assert s3.client -# assert len(boto3.client._mock_mock_calls) == 6 -# sleep(1 - (time() - s3._last_time)) -# assert s3.client -# assert s3.client -# assert len(boto3.client._mock_mock_calls) == 9 - -# assert instance_metadata_provider._mock_call_count == 0 if use_shared_credentials else 3 diff --git a/tests/streaming/test_downloader.py b/tests/streaming/test_downloader.py index 6bee221d..97368c0c 100644 --- a/tests/streaming/test_downloader.py +++ b/tests/streaming/test_downloader.py @@ -6,14 +6,6 @@ shutil, ) -# def test_s3_downloader_fast(tmpdir, monkeypatch): -# monkeypatch.setattr(os, "system", MagicMock(return_value=0)) -# popen_mock = MagicMock() -# monkeypatch.setattr(subprocess, "Popen", MagicMock(return_value=popen_mock)) -# downloader = S3Downloader(tmpdir, tmpdir, []) -# downloader.download_file("s3://random_bucket/a.txt", os.path.join(tmpdir, "a.txt")) -# popen_mock.wait.assert_called() - def test_download_with_cache(tmpdir, monkeypatch): # Create a file to download/cache From 2036e37bc989efadbb208d5784ac067be1d8266c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Sep 2024 04:37:53 +0000 Subject: [PATCH 39/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/processing/data_processor.py | 1 - 
src/litdata/streaming/downloader.py | 2 -- tests/streaming/test_resolver.py | 1 - 3 files changed, 4 deletions(-) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index c155cce5..c5124c99 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -135,7 +135,6 @@ def _download_data_target( input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue, storage_options: Optional[Dict] = {} ) -> None: """This function is used to download data from a remote directory to a cache directory to optimise reading.""" - while True: # 2. Fetch from the queue r: Optional[Tuple[int, List[str]]] = queue_in.get() diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 212955b9..463ab576 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -16,9 +16,7 @@ import shutil import subprocess from abc import ABC - from typing import Any, Dict, List, Optional, Union -from contextlib import suppress from urllib import parse import fsspec diff --git a/tests/streaming/test_resolver.py b/tests/streaming/test_resolver.py index f96354ac..36623d81 100644 --- a/tests/streaming/test_resolver.py +++ b/tests/streaming/test_resolver.py @@ -371,7 +371,6 @@ def test_resolve_dir_absolute(tmp_path, monkeypatch): def test_resolve_dir_unsupported_cloud_provider(monkeypatch, tmp_path): """Test that the unsupported cloud provider is handled correctly.""" - test_dir = "some-random-cloud-provider://some-random-bucket" with pytest.raises(ValueError, match="The provided dir_path"): resolver._resolve_dir(test_dir) From e230cebd8f5dc30f9f5840a8bf8c4d4685b84f92 Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 18 Sep 2024 10:09:42 +0530 Subject: [PATCH 40/43] update --- src/litdata/processing/functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index 4285e0c6..dab4b79d 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -546,6 +546,7 @@ def merge_datasets(input_dirs: List[str], output_dir: str, storage_options: Opti Args: input_dirs: A list of directories pointing to the existing optimized datasets. output_dir: The directory where the merged dataset would be stored. + storage_options: A dictionary of storage options to be passed to the fsspec library. """ if len(input_dirs) == 0: From e60f9aedf0f887b05ca137630e68d5dd01dac486 Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 18 Sep 2024 11:38:39 +0530 Subject: [PATCH 41/43] readme updated --- README.md | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index bab69ed9..470cace5 100644 --- a/README.md +++ b/README.md @@ -217,9 +217,8 @@ Additionally, you can inject client connection settings for [S3](https://boto3.a from litdata import StreamingDataset storage_options = { - "endpoint_url": "your_endpoint_url", - "aws_access_key_id": "your_access_key_id", - "aws_secret_access_key": "your_secret_access_key", + "key": "your_access_key_id", + "secret": "your_secret_access_key", } dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options) @@ -264,7 +263,7 @@ for batch in val_dataloader:   -The StreamingDataset supports reading optimized datasets from common cloud providers. +The StreamingDataset supports reading optimized datasets from common cloud providers. 
```python import os @@ -272,25 +271,39 @@ import litdata as ld # Read data from AWS S3 aws_storage_options={ - "AWS_ACCESS_KEY_ID": os.environ['AWS_ACCESS_KEY_ID'], - "AWS_SECRET_ACCESS_KEY": os.environ['AWS_SECRET_ACCESS_KEY'], + "key": os.environ['AWS_ACCESS_KEY_ID'], + "secret": os.environ['AWS_SECRET_ACCESS_KEY'], } dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options) # Read data from GCS gcp_storage_options={ - "project": os.environ['PROJECT_ID'], + "token": { + # dumped from cat ~/.config/gcloud/application_default_credentials.json + "account": "", + "client_id": "your_client_id", + "client_secret": "your_client_secret", + "quota_project_id": "your_quota_project_id", + "refresh_token": "your_refresh_token", + "type": "authorized_user", + "universe_domain": "googleapis.com", + } } dataset = ld.StreamingDataset("gs://my-bucket/my-data", storage_options=gcp_storage_options) # Read data from Azure azure_storage_options={ - "account_url": f"https://{os.environ['AZURE_ACCOUNT_NAME']}.blob.core.windows.net", - "credential": os.environ['AZURE_ACCOUNT_ACCESS_KEY'] + "account_name": "azure_account_name", + "account_key": os.environ['AZURE_ACCOUNT_ACCESS_KEY'] } dataset = ld.StreamingDataset("azure://my-bucket/my-data", storage_options=azure_storage_options) ``` +- For more details on which storage options are supported, please refer to: + - [AWS S3 storage options](https://github.com/fsspec/s3fs/blob/main/s3fs/core.py#L176) + - [GCS storage options](https://github.com/fsspec/gcsfs/blob/main/gcsfs/core.py#L154) + - [Azure storage options](https://github.com/fsspec/adlfs/blob/main/adlfs/spec.py#L124) +
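
Taken together, the patches above reduce the download path to a small resolution step: map the URL prefix to an fsspec provider, merge provider defaults under any user-supplied `storage_options`, and hand the result to `fsspec.filesystem`. The snippet below is a condensed, illustrative sketch of that flow — it mirrors `_DOWNLOADERS`, `_DEFAULT_STORAGE_OPTIONS`, and `get_complete_storage_options` from the patched `downloader.py`, but it is not the library's exact code and assumes only that `fsspec` (plus `s3fs`/`gcsfs`/`adlfs` for cloud URLs) is installed:

```python
# Illustrative sketch of the fsspec-based resolution introduced in downloader.py.
from typing import Any, Dict, Optional

import fsspec

_PROVIDER_BY_PREFIX = {
    "s3://": "s3",
    "gs://": "gs",
    "azure://": "abfs",
    "abfs://": "abfs",
    "local:": "file",
    "": "file",  # fallback: treat anything else as a local path
}

# Provider-specific defaults merged *under* user-supplied storage_options.
_DEFAULT_STORAGE_OPTIONS = {
    "s3": {"config_kwargs": {"retries": {"max_attempts": 1000, "mode": "adaptive"}}},
}


def resolve_filesystem(remote_path: str, storage_options: Optional[Dict[str, Any]] = None):
    """Pick an fsspec provider from the URL prefix and merge the default options."""
    provider = next(p for prefix, p in _PROVIDER_BY_PREFIX.items() if remote_path.startswith(prefix))
    options = {**_DEFAULT_STORAGE_OPTIONS.get(provider, {}), **(storage_options or {})}
    return fsspec.filesystem(provider, **options)


if __name__ == "__main__":
    # Local paths need no credentials; cloud URLs would take the storage_options
    # dictionaries shown in the README section above (e.g. {"key": ..., "secret": ...} for S3).
    fs = resolve_filesystem("local:/tmp")
    print(type(fs).__name__)  # LocalFileSystem
```

Because user options are merged last, the adaptive-retry defaults for S3 apply automatically but never override anything passed explicitly.
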
From feb5d48f572e51c46c76e8fcea37c92084ad5c0e Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 18 Sep 2024 12:09:26 +0530 Subject: [PATCH 42/43] increase test_dataset_resume_on_future_chunk timeout time to 120 seconds --- tests/streaming/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index ef93021c..b87f9ffd 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -889,7 +889,7 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False): @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) -@pytest.mark.timeout(60) +@pytest.mark.timeout(120) @pytest.mark.parametrize("shuffle", [True, False]) def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): """Tests resuming from a chunk past the first chunk, when subsequent chunks don't have the same size.""" From b5ec077c7309c298cb5fa8dd9e17f2ae17e71bbb Mon Sep 17 00:00:00 2001 From: deependu Date: Wed, 18 Sep 2024 13:07:27 +0530 Subject: [PATCH 43/43] update --- src/litdata/processing/data_processor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index c5124c99..fae806b5 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -1140,7 +1140,11 @@ def run(self, data_recipe: DataRecipe) -> None: # Exit early if all the workers are done. # This means there were some kinda of errors. if all(not w.is_alive() for w in self.workers): - raise RuntimeError("One of the worker has failed") + try: + error = self.error_queue.get(timeout=0.001) + self._exit_on_error(error) + except Empty: + break if _TQDM_AVAILABLE: pbar.close()
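
The final patch swaps the blanket `RuntimeError("One of the worker has failed")` for a short, non-blocking read of `self.error_queue`, so the worker's own exception is surfaced once every worker has exited. The sketch below illustrates that pattern in isolation; `_worker` and `_supervise` are hypothetical stand-ins for demonstration, not functions from the library:

```python
# Standalone sketch of the error-surfacing pattern adopted in data_processor.py:
# once every worker has exited, try a short get on the shared error queue and
# re-raise the worker's exception instead of a generic failure.
import multiprocessing as mp
from queue import Empty


def _worker(error_queue) -> None:
    try:
        raise ValueError("boom")  # stand-in for a real processing failure
    except Exception as exc:
        error_queue.put(exc)  # forward the failure to the parent process


def _supervise(workers, error_queue) -> None:
    # Wait for all workers; real code would interleave this with progress handling.
    for w in workers:
        w.join()
    try:
        error = error_queue.get(timeout=0.001)
    except Empty:
        return  # every worker exited cleanly
    raise RuntimeError("One of the workers has failed") from error


if __name__ == "__main__":
    queue = mp.Queue()
    procs = [mp.Process(target=_worker, args=(queue,)) for _ in range(2)]
    for p in procs:
        p.start()
    try:
        _supervise(procs, queue)
    except RuntimeError as err:
        print(f"surfaced: {err!r}, caused by {err.__cause__!r}")
```

Polling the queue with a tiny timeout keeps the supervisor responsive while still preferring the worker's own error over a generic failure message.
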