diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 93cb15a1..104ebe5c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -46,13 +46,6 @@ repos: additional_dependencies: [tomli] #args: ["--write-changes"] # uncomment if you want to get automatic fixing - - repo: https://github.com/PyCQA/docformatter - rev: v1.7.5 - hooks: - - id: docformatter - additional_dependencies: [tomli] - args: ["--in-place"] - - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.3.2 hooks: diff --git a/README.md b/README.md index bab69ed9..470cace5 100644 --- a/README.md +++ b/README.md @@ -217,9 +217,8 @@ Additionally, you can inject client connection settings for [S3](https://boto3.a from litdata import StreamingDataset storage_options = { - "endpoint_url": "your_endpoint_url", - "aws_access_key_id": "your_access_key_id", - "aws_secret_access_key": "your_secret_access_key", + "key": "your_access_key_id", + "secret": "your_secret_access_key", } dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options) @@ -264,7 +263,7 @@ for batch in val_dataloader:   -The StreamingDataset supports reading optimized datasets from common cloud providers. +The StreamingDataset supports reading optimized datasets from common cloud providers. ```python import os @@ -272,25 +271,39 @@ import litdata as ld # Read data from AWS S3 aws_storage_options={ - "AWS_ACCESS_KEY_ID": os.environ['AWS_ACCESS_KEY_ID'], - "AWS_SECRET_ACCESS_KEY": os.environ['AWS_SECRET_ACCESS_KEY'], + "key": os.environ['AWS_ACCESS_KEY_ID'], + "secret": os.environ['AWS_SECRET_ACCESS_KEY'], } dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options) # Read data from GCS gcp_storage_options={ - "project": os.environ['PROJECT_ID'], + "token": { + # dumped from cat ~/.config/gcloud/application_default_credentials.json + "account": "", + "client_id": "your_client_id", + "client_secret": "your_client_secret", + "quota_project_id": "your_quota_project_id", + "refresh_token": "your_refresh_token", + "type": "authorized_user", + "universe_domain": "googleapis.com", + } } dataset = ld.StreamingDataset("gs://my-bucket/my-data", storage_options=gcp_storage_options) # Read data from Azure azure_storage_options={ - "account_url": f"https://{os.environ['AZURE_ACCOUNT_NAME']}.blob.core.windows.net", - "credential": os.environ['AZURE_ACCOUNT_ACCESS_KEY'] + "account_name": "azure_account_name", + "account_key": os.environ['AZURE_ACCOUNT_ACCESS_KEY'] } dataset = ld.StreamingDataset("azure://my-bucket/my-data", storage_options=azure_storage_options) ``` +- For more details on which storage options are supported, please refer to: + - [AWS S3 storage options](https://github.com/fsspec/s3fs/blob/main/s3fs/core.py#L176) + - [GCS storage options](https://github.com/fsspec/gcsfs/blob/main/gcsfs/core.py#L154) + - [Azure storage options](https://github.com/fsspec/adlfs/blob/main/adlfs/spec.py#L124) +
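Since these `storage_options` are forwarded to the underlying fsspec filesystem (s3fs, gcsfs or adlfs), any constructor argument those libraries accept can be injected the same way. Below is a minimal sketch for an S3-compatible endpoint such as MinIO; the endpoint URL and bucket are placeholders, and the `client_kwargs`/`endpoint_url` keys follow the s3fs convention documented in the link above rather than anything specific to litdata:

```python
import litdata as ld

# Sketch: stream from an S3-compatible object store by overriding the endpoint.
# "key", "secret" and "client_kwargs" are s3fs constructor arguments; all values are placeholders.
storage_options = {
    "key": "your_access_key_id",
    "secret": "your_secret_access_key",
    "client_kwargs": {"endpoint_url": "https://my-object-store.example.com"},
}

dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=storage_options)
```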
diff --git a/examples/multi_modal/create_labelencoder.py b/examples/multi_modal/create_labelencoder.py index 657318bc..83549924 100644 --- a/examples/multi_modal/create_labelencoder.py +++ b/examples/multi_modal/create_labelencoder.py @@ -3,11 +3,7 @@ def create_labelencoder(): - """ - Create a label encoder - Returns: - - """ + """Create a label encoder.""" data = ["Cancelation", "IBAN Change", "Damage Report"] # Create an instance of LabelEncoder label_encoder = LabelEncoder() diff --git a/examples/multi_modal/dataloader.py b/examples/multi_modal/dataloader.py index df28dc2e..c33a3cf7 100644 --- a/examples/multi_modal/dataloader.py +++ b/examples/multi_modal/dataloader.py @@ -29,15 +29,12 @@ def __init__(self): self.hyperparameters = HYPERPARAMETERS def load_labelencoder(self): - """ - Function to load the label encoder from s3 - Returns: - """ + """Load the label encoder from s3.""" return joblib.load(self.hyperparameters["label_encoder_name"]) def load_tokenizer(self): - """ - load the tokenizer files and the pre-training model path from s3 specified in the hyperparameters + """Loads the tokenizer files and the pre-training model path from s3 specified in the hyperparameters. + Returns: tokenizer """ # Load Bert tokenizer @@ -62,13 +59,10 @@ def __init__(self, input_dir: Union[str, Any], hyperparameters: Union[dict, Any] self.labelencoder = EC.load_labelencoder() def tokenize_data(self, tokenizer, texts, max_length: int): - """ - Tokenize the text - Args: - tokenizer: - texts: - max_length: - Returns: input_ids, attention_masks + """Tokenize the text. + + Returns: input_ids, attention_masks. + """ encoded_text = tokenizer( texts, @@ -98,11 +92,10 @@ class MixedDataModule(pl.LightningDataModule): """Own DataModule form the pytorch lightning DataModule.""" def __init__(self, hyperparameters: dict): - """ - Init if the Data Module + """Initialize the Data Module. + Args: - data_path: dataframe with the data - hyperparameters: Hyperparameters + hyperparameters: Hyperparameters. """ super().__init__() self.hyperparameters = hyperparameters @@ -130,10 +123,11 @@ def __init__(self, hyperparameters: dict): ) def train_dataloader(self) -> DataLoader: - """ - Define the training dataloader + """Define the training dataloader. + Returns: - training dataloader + training dataloader. + """ dataset_train = DocumentClassificationDataset( hyperparameters=self.hyperparameters, @@ -150,10 +144,10 @@ def train_dataloader(self) -> DataLoader: ) def val_dataloader(self) -> DataLoader: - """ - Define the validation dataloader + """Defines the validation dataloader. + Returns: - validation dataloader + validation dataloader. """ dataset_val = DocumentClassificationDataset( hyperparameters=self.hyperparameters, @@ -169,8 +163,8 @@ def val_dataloader(self) -> DataLoader: ) def test_dataloader(self) -> DataLoader: - """ - Define the test dataloader + """Defines the test dataloader. 
+ Returns: test dataloader """ diff --git a/examples/multi_modal/loop.py b/examples/multi_modal/loop.py index 6104f6e3..2c2de659 100644 --- a/examples/multi_modal/loop.py +++ b/examples/multi_modal/loop.py @@ -77,7 +77,6 @@ def save_reports(self, model_dir, mode, report_confusion_matrix, report): mode: train, test or val report_confusion_matrix: sklearn confusion matrix report: sklear classification report - Returns: """ df_cm = pd.DataFrame(report_confusion_matrix) @@ -87,17 +86,7 @@ def save_reports(self, model_dir, mode, report_confusion_matrix, report): logger.info("Confusion Matrix and Classification report are saved.") def save_test_evaluations(self, model_dir, mode, y_pred, y_true, confis, numerical_id_): - """ - Save a pandas dataframe with prediction and ground truth and identifier (numerical id) of the test dataset - Args: - model_dir: - mode: - y_pred: - y_true: - confis: - numerical_id_: - Returns: - """ + """Save pandas dataframe with prediction and ground truth and identifier (numerical id) of the test dataset.""" df_test = pd.DataFrame() df_test["pred"] = y_pred df_test["confidence"] = confis.max(axis=1) @@ -151,8 +140,7 @@ def forward( """Forward path, calculate the computational graph in the forward direction. Used for train, test and val. - Args: - y: tensor with text data as tokens + Returns: computional graph @@ -160,34 +148,29 @@ def forward( return self.module(x, y, z) def training_step(self, batch: Dict[str, torch.Tensor]) -> Dict: - """ - Call the eval share for training - Args: - batch: tensor + """Call the shared eval step for training. + Returns: - dict with loss, outputs and ground_truth + dict with loss, outputs and ground_truth. + """ return self._shared_eval_step(batch, "train") def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict: - """ - Call the eval share for validation - Args: - batch: - batch_idx: + """Call the shared eval step for validation. + Returns: - dict with loss, outputs and ground_truth + dict with loss, outputs and ground_truth. + """ return self._shared_eval_step(batch, "val") def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict: - """ - Call the eval share for test - Args: - batch: - batch_idx: + """Call the shared eval step for test. + Returns: - dict with loss, outputs and ground_truth + dict with loss, outputs and ground_truth. + """ ret = self._shared_eval_step(batch, "test") self.pred_list.append(ret) @@ -199,6 +182,7 @@ def _shared_eval_step(self, batch: Dict[str, torch.Tensor], mode: str) -> Dict: Args: batch: tensor mode: train, test or val + Returns: dict with loss, outputs and ground_truth @@ -227,14 +211,8 @@ def _shared_eval_step(self, batch: Dict[str, torch.Tensor], mode: str) -> Dict: return {"outputs": out, "loss": loss, "ground_truth": ground_truth, "numerical_id": numerical_id} - def _epoch_end(self, mode: str): - """ - Calculate loss and metricies at end of epoch - Args: - mode: - Returns: - None - """ + def _epoch_end(self, mode: str) -> None: + """Calculate loss and metrics at end of epoch.""" if mode == "val": output = self.val_metrics.compute() self.log_dict(output) @@ -249,14 +227,7 @@ def _epoch_end(self, mode: str): self.test_metrics.reset() def predict(self, batch: Dict[str, torch.Tensor], batch_idx: int = 0, dataloader_idx: int = 0) -> torch.Tensor: - """Model prediction without softmax and argmax to predict class label. 
- - Args: - outputs: - Returns: - None - - """ + """Model prediction without softmax and argmax to predict class label.""" self.eval() with torch.no_grad(): ids = batch["ID"] @@ -265,51 +236,30 @@ def predict(self, batch: Dict[str, torch.Tensor], batch_idx: int = 0, dataloader return self.forward(ids, atts, img) def on_test_epoch_end(self) -> None: - """ - Calculate the metrics at the end of epoch for test step - Args: - outputs: - Returns: - None - """ + """Calculate the metrics at the end of epoch for test step.""" self._epoch_end("test") - def on_validation_epoch_end(self): - """ - Calculate the metrics at the end of epoch for val step - Args: - outputs: - Returns: - None - """ + def on_validation_epoch_end(self) -> None: + """Calculate the metrics at the end of epoch for val step.""" self._epoch_end("val") - def on_train_epoch_end(self): - """ - Calculate the metrics at the end of epoch for train step - Args: - outputs: - Returns: - None - """ + def on_train_epoch_end(self) -> None: + """Calculate the metrics at the end of epoch for train step.""" self._epoch_end("train") def configure_optimizers(self) -> Any: - """ - Configure the optimizer + """Configure the optimizer. + Returns: optimizer + """ optimizer = AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.hyperparameters["weight_decay"]) scheduler = StepLR(optimizer, step_size=1, gamma=0.1) return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}] def configure_callbacks(self) -> Union[Sequence[pl.pytorch.Callback], pl.pytorch.Callback]: - """Configure Early stopping or Model Checkpointing. - - Returns: - - """ + """Configure Early stopping or Model Checkpointing.""" early_stop = EarlyStopping( monitor="val_MulticlassAccuracy", patience=self.hyperparameters["patience"], mode="max" ) diff --git a/examples/multi_modal/model_arc.py b/examples/multi_modal/model_arc.py index 4ea8cddb..8eb57ca4 100644 --- a/examples/multi_modal/model_arc.py +++ b/examples/multi_modal/model_arc.py @@ -41,9 +41,7 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor): """Forward path, calculate the computational graph in the forward direction. Used for train, test and val. - Args: - input_ids - attention_mask + Returns: computional graph @@ -72,9 +70,9 @@ def __init__(self, endpoint_mode: bool, hyperparameters: dict): self.dropout = nn.Dropout(self.hyperparameters["dropout"]) def get_bert_model(self): - """ - Load the pre trained bert model weights - Returns: model + """Load the pre-trained bert model weights. + + Returns: model. """ model = BertModel.from_pretrained("bert-base-cased") return BertClassifier(model) @@ -89,9 +87,9 @@ def forward( validation. Args: - x (torch.Tensor): Tensor with id token - y (torch.Tensor): Tensor with attention tokens. - z (torch.Tensor): Tensor with image. + x: Tensor with id token + y: Tensor with attention tokens. + z: Tensor with image. Returns: torch.Tensor: The output tensor representing the computational graph. diff --git a/pyproject.toml b/pyproject.toml index b78d1cf0..b0883f1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,12 @@ ignore-words-list = "cancelation" [tool.ruff] line-length = 120 target-version = "py38" +# Exclude a variety of commonly ignored directories. +exclude = [ + ".git", + "docs", + "src/litdata/utilities/_pytree.py", +] # Enable Pyflakes `E` and `F` codes by default. 
lint.select = [ "E", "W", # see: https://pypi.org/project/pycodestyle @@ -53,45 +59,54 @@ lint.select = [ "UP", # see: pyupgrade ] lint.extend-select = [ + "D", "I", # see: isort "C4", # see: https://pypi.org/project/flake8-comprehensions "SIM", # see: https://pypi.org/project/flake8-simplify "RET", # see: https://pypi.org/project/flake8-return "PT", # see: https://pypi.org/project/flake8-pytest-style "NPY201", # see: https://docs.astral.sh/ruff/rules/numpy2-deprecation - "RUF100" # yesqa + "RUF100", # yesqa ] lint.ignore = [ "E731", # Do not assign a lambda expression, use a def "S101", # todo: Use of `assert` detected ] -# Exclude a variety of commonly ignored directories. -exclude = [ - ".git", - "docs", - "src/litdata/utilities/_pytree.py", -] lint.ignore-init-module-imports = true +# Unlike Flake8, default to a complexity level of 10. +lint.mccabe.max-complexity = 10 +# Use Google-style docstrings. +lint.pydocstyle.convention = "google" [tool.ruff.lint.per-file-ignores] -".actions/*" = ["S101", "S310"] -"setup.py" = ["S101", "SIM115"] +"setup.py" = ["D100", "SIM115"] "examples/**" = [ + "D100", "D101", "D102", "D103", "D104", "D105", "D107", # Missing docstring in public module, class, method, function, package + "D205", # todo: 1 blank line required between summary line and description + "D401", "D404", # First line should be in imperative mood; try rephrasing "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes ] "src/**" = [ + "D100", # Missing docstring in public module + "D101", # todo: Missing docstring in public class + "D102", # todo: Missing docstring in public method + "D103", # todo: Missing docstring in public function + "D104", # Missing docstring in public package + "D105", # todo: Missing docstring in magic method + "D107", # todo: Missing docstring in __init__ + "D205", # todo: 1 blank line required between summary line and description + "D401", "D404", # todo: First line should be in imperative mood; try rephrasing "S602", # todo: `subprocess` call with `shell=True` identified, security issue "S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell` "S607", # todo: Starting a process with a partial executable path "S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. ] "tests/**" = [ + "D100", "D101", "D102", "D103", "D104", "D105", "D107", # Missing docstring in public module, class, method, function, package + "D401", "D404", # First line should be in imperative mood; try rephrasing "S105", "S106", # todo: Possible hardcoded password: ... ] -[tool.ruff.lint.mccabe] -# Unlike Flake8, default to a complexity level of 10. 
-max-complexity = 10 [tool.mypy] diff --git a/requirements.txt b/requirements.txt index 06a629a0..ec443722 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,7 @@ torch lightning-utilities filelock numpy -boto3 +# boto3 requests +fsspec +fsspec[s3] # aws s3 diff --git a/requirements/extras.txt b/requirements/extras.txt index 385e2e81..33d42446 100644 --- a/requirements/extras.txt +++ b/requirements/extras.txt @@ -5,3 +5,5 @@ pyarrow tqdm lightning-sdk ==0.1.17 # Must be pinned to ensure compatibility google-cloud-storage +fsspec[gs] # google cloud storage +fsspec[abfs] # azure blob diff --git a/src/litdata/__about__.py b/src/litdata/__about__.py index 9aee9d79..5bee1c26 100644 --- a/src/litdata/__about__.py +++ b/src/litdata/__about__.py @@ -14,7 +14,7 @@ import time -__version__ = "0.2.26" +__version__ = "0.2.27" __author__ = "Lightning AI et al." __author_email__ = "pytorch@lightning.ai" __license__ = "Apache-2.0" diff --git a/src/litdata/constants.py b/src/litdata/constants.py index a6a714c7..efe2e248 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -85,3 +85,4 @@ _TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ" _IS_IN_STUDIO = bool(os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)) and bool(os.getenv("LIGHTNING_CLUSTER_ID", None)) _ENABLE_STATUS = bool(int(os.getenv("ENABLE_STATUS_REPORT", "0"))) +_SUPPORTED_CLOUD_PROVIDERS = ["s3", "gs", "azure", "abfs"] diff --git a/src/litdata/imports.py b/src/litdata/imports.py index 3d415569..b4288ed9 100644 --- a/src/litdata/imports.py +++ b/src/litdata/imports.py @@ -97,7 +97,7 @@ def _check_requirement(self) -> None: self.message = f"Requirement {self.requirement!r} met" except Exception as ex: self.available = False - self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`" + self.message = f"{ex.__class__.__name__}: {ex}.\n HINT: Try running `pip install -U {self.requirement!r}`" requirement_contains_version_specifier = any(c in self.requirement for c in "=<>") if not requirement_contains_version_specifier or self.module is not None: module = self.requirement if self.module is None else self.module diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index a246c68b..9dfba498 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -32,8 +32,6 @@ from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union from urllib import parse -import boto3 -import botocore import numpy as np import torch @@ -42,14 +40,21 @@ _ENABLE_STATUS, _INDEX_FILENAME, _IS_IN_STUDIO, + _SUPPORTED_CLOUD_PROVIDERS, _TQDM_AVAILABLE, ) from litdata.processing.readers import BaseReader, StreamingDataLoaderReader -from litdata.processing.utilities import _create_dataset, download_directory_from_S3, remove_uuid_from_filename +from litdata.processing.utilities import _create_dataset, remove_uuid_from_filename from litdata.streaming import Cache from litdata.streaming.cache import Dir -from litdata.streaming.client import S3Client from litdata.streaming.dataloader import StreamingDataLoader +from litdata.streaming.downloader import ( + does_file_exist, + download_file_or_directory, + get_cloud_provider, + remove_file_or_directory, + upload_file_or_directory, +) from litdata.streaming.item_loader import BaseItemLoader from litdata.streaming.resolver import _resolve_dir from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads @@ -96,14 +101,22 @@ def _get_cache_data_dir(name: 
Optional[str] = None) -> str: return os.path.join(cache_dir, name.lstrip("/")) -def _wait_for_file_to_exist(s3: S3Client, obj: parse.ParseResult, sleep_time: int = 2) -> Any: - """This function check.""" +def _wait_for_file_to_exist( + remote_filepath: str, sleep_time: int = 2, wait_for_count: int = 5, storage_options: Optional[Dict] = {} +) -> Any: + """This function checks if a file exists on the remote storage. + + If not, it waits for a while and tries again. + + """ + cloud_provider = get_cloud_provider(remote_filepath) while True: try: - return s3.client.head_object(Bucket=obj.netloc, Key=obj.path.lstrip("/")) - except botocore.exceptions.ClientError as e: - if "the HeadObject operation: Not Found" in str(e): + return does_file_exist(remote_filepath, cloud_provider, storage_options=storage_options) + except Exception as e: + if wait_for_count > 0: sleep(sleep_time) + wait_for_count -= 1 else: raise e @@ -118,10 +131,10 @@ def _wait_for_disk_usage_higher_than_threshold(input_dir: str, threshold_in_gb: return -def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue) -> None: +def _download_data_target( + input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue, storage_options: Optional[Dict] = {} +) -> None: """This function is used to download data from a remote directory to a cache directory to optimise reading.""" - s3 = S3Client() - while True: # 2. Fetch from the queue r: Optional[Tuple[int, List[str]]] = queue_in.get() @@ -156,13 +169,11 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue obj = parse.urlparse(path) - if obj.scheme == "s3": + if obj.scheme in _SUPPORTED_CLOUD_PROVIDERS: dirpath = os.path.dirname(local_path) os.makedirs(dirpath, exist_ok=True) - - with open(local_path, "wb") as f: - s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f) + download_file_or_directory(path, local_path, storage_options=storage_options) elif os.path.isfile(path): if not path.startswith("/teamspace/studios/this_studio"): @@ -176,7 +187,7 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue def _remove_target(input_dir: Dir, cache_dir: str, queue_in: Queue) -> None: - """This function is used to delete files from the cache directory to minimise disk space.""" + """Delete files from the cache directory to minimise disk space.""" while True: # 1. 
Collect paths paths = queue_in.get() @@ -198,12 +209,13 @@ def _remove_target(input_dir: Dir, cache_dir: str, queue_in: Queue) -> None: os.remove(path) -def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir) -> None: +def _upload_fn( + upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir, storage_options: Optional[Dict] = {} +) -> None: """This function is used to upload optimised chunks from a local to remote dataset directory.""" obj = parse.urlparse(output_dir.url if output_dir.url else output_dir.path) - if obj.scheme == "s3": - s3 = S3Client() + is_remote = obj.scheme in _SUPPORTED_CLOUD_PROVIDERS while True: data: Optional[Union[str, Tuple[str, str]]] = upload_queue.get() @@ -223,7 +235,7 @@ if not local_filepath.startswith(cache_dir): local_filepath = os.path.join(cache_dir, local_filepath) - if obj.scheme == "s3": + if is_remote: try: output_filepath = str(obj.path).lstrip("/") @@ -235,12 +247,8 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_ output_filepath = os.path.join(output_filepath, local_filepath.replace(tmpdir, "")[1:]) output_filepath = remove_uuid_from_filename(output_filepath) # remove unique id from checkpoints - - s3.client.upload_file( - local_filepath, - obj.netloc, - output_filepath, - ) + remote_filepath = str(obj.scheme) + "://" + str(obj.netloc) + "/" + output_filepath + upload_file_or_directory(local_filepath, remote_filepath, storage_options=storage_options) except Exception as e: print(e) @@ -417,6 +425,7 @@ def __init__( checkpoint_chunks_info: Optional[List[Dict[str, Any]]] = None, checkpoint_next_index: Optional[int] = None, item_loader: Optional[BaseItemLoader] = None, + storage_options: Optional[Dict] = {}, ) -> None: """The BaseWorker is responsible to process the user data.""" self.worker_index = worker_index @@ -451,6 +460,7 @@ def __init__( self.use_checkpoint: bool = use_checkpoint self.checkpoint_chunks_info: Optional[List[Dict[str, Any]]] = checkpoint_chunks_info self.checkpoint_next_index: Optional[int] = checkpoint_next_index + self.storage_options = storage_options def run(self) -> None: try: @@ -641,6 +651,7 @@ def _start_downloaders(self) -> None: self.cache_data_dir, to_download_queue, self.ready_to_process_queue, + self.storage_options, ), ) p.start() @@ -680,6 +691,7 @@ def _start_uploaders(self) -> None: self.remove_queue, self.cache_chunks_dir, self.output_dir, + self.storage_options, ), ) p.start() @@ -781,6 +793,7 @@ def __init__( chunk_bytes: Optional[Union[int, str]] = None, compression: Optional[str] = None, encryption: Optional[Encryption] = None, + storage_options: Optional[Dict] = {}, ): super().__init__() if chunk_size is not None and chunk_bytes is not None: @@ -790,6 +803,7 @@ self.chunk_bytes = 1 << 26 if chunk_size is None and chunk_bytes is None else chunk_bytes self.compression = compression self.encryption = encryption + self.storage_options = storage_options @abstractmethod def prepare_structure(self, input_dir: Optional[str]) -> List[T]: @@ -801,7 +815,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> List[T]: @abstractmethod def prepare_item(self, item_metadata: T) -> Any: - """The return of this `prepare_item` method is persisted in chunked binary files.""" + """The output of the `prepare_item` method is persisted in chunked binary files.""" def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Result: num_nodes 
= _get_num_nodes() @@ -846,7 +860,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul ) def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_rank: Optional[int]) -> None: - """This method upload the index file to the remote cloud directory.""" + """Upload the index file to the remote cloud directory.""" if output_dir.path is None and output_dir.url is None: return @@ -856,10 +870,12 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra else: local_filepath = os.path.join(cache_dir, _INDEX_FILENAME) - if obj.scheme == "s3": - s3 = S3Client() - s3.client.upload_file( - local_filepath, obj.netloc, os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath)) + if obj.scheme in _SUPPORTED_CLOUD_PROVIDERS: + remote_filepath = str(obj.scheme) + "://" + str(obj.netloc) + "/" + upload_file_or_directory( + local_filepath, + remote_filepath + os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath)), + storage_options=self.storage_options, ) elif output_dir.path and os.path.isdir(output_dir.path): shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath))) @@ -877,11 +893,13 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra assert output_dir_path remote_filepath = os.path.join(output_dir_path, f"{node_rank}-{_INDEX_FILENAME}") node_index_filepath = os.path.join(cache_dir, os.path.basename(remote_filepath)) - if obj.scheme == "s3": - obj = parse.urlparse(remote_filepath) - _wait_for_file_to_exist(s3, obj) - with open(node_index_filepath, "wb") as f: - s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f) + if obj.scheme in _SUPPORTED_CLOUD_PROVIDERS: + _wait_for_file_to_exist(remote_filepath, storage_options=self.storage_options) + download_file_or_directory( + remote_filepath, + node_index_filepath, + storage_options=self.storage_options, + ) elif output_dir.path and os.path.isdir(output_dir.path): shutil.copyfile(remote_filepath, node_index_filepath) @@ -922,11 +940,11 @@ def __init__( use_checkpoint: bool = False, item_loader: Optional[BaseItemLoader] = None, start_method: Optional[str] = None, + storage_options: Optional[Dict] = {}, ): - """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make - training faster. + """Provides an efficient way to process data across multiple machines into chunks to make training faster. - Arguments: + Args: input_dir: The path to where the input data are stored. output_dir: The path to where the output data are stored. num_workers: The number of worker threads to use. @@ -947,6 +965,7 @@ def __init__( the format in which the data is stored and optimized for loading. start_method: The start method used by python multiprocessing package. Default to spawn unless running inside an interactive shell like Ipython. + storage_options: The storage options used by the cloud provider. 
""" # spawn doesn't work in IPython @@ -983,6 +1002,7 @@ def __init__( self.item_loader = item_loader self.state_dict = state_dict or {rank: 0 for rank in range(self.num_workers)} + self.storage_options = storage_options if self.reader is not None and self.weights is not None: raise ValueError("Either the reader or the weights needs to be defined.") @@ -998,7 +1018,7 @@ def __init__( self.random_seed = random_seed def run(self, data_recipe: DataRecipe) -> None: - """The `DataProcessor.run(...)` method triggers the data recipe processing over your dataset.""" + """Triggers the data recipe processing over your dataset.""" if not isinstance(data_recipe, DataRecipe): raise ValueError("The provided value should be a data recipe.") if not self.use_checkpoint and isinstance(data_recipe, DataChunkRecipe): @@ -1139,7 +1159,11 @@ def run(self, data_recipe: DataRecipe) -> None: # This means there were some kinda of errors. # TODO: Check whether this is still required. if all(not w.is_alive() for w in self.workers): - raise RuntimeError("One of the worker has failed") + try: + error = self.error_queue.get(timeout=0.001) + self._exit_on_error(error) + except Empty: + break if _TQDM_AVAILABLE: pbar.close() @@ -1150,12 +1174,11 @@ def run(self, data_recipe: DataRecipe) -> None: if num_nodes == node_rank + 1 and self.output_dir.url and self.output_dir.path is not None and _IS_IN_STUDIO: from lightning_sdk.lightning_cloud.openapi import V1DatasetType + data_type = V1DatasetType.CHUNKED if isinstance(data_recipe, DataChunkRecipe) else V1DatasetType.TRANSFORMED _create_dataset( input_dir=self.input_dir.path, storage_dir=self.output_dir.path, - dataset_type=V1DatasetType.CHUNKED - if isinstance(data_recipe, DataChunkRecipe) - else V1DatasetType.TRANSFORMED, + dataset_type=data_type, empty=False, size=result.size, num_bytes=result.num_bytes, @@ -1202,6 +1225,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L self.checkpoint_chunks_info[worker_idx] if self.checkpoint_chunks_info else None, self.checkpoint_next_index[worker_idx] if self.checkpoint_next_index else None, self.item_loader, + storage_options=self.storage_options, ) worker.start() workers.append(worker) @@ -1253,21 +1277,14 @@ def _cleanup_checkpoints(self) -> None: obj = parse.urlparse(self.output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. Found {self.output_dir.path}.") - - s3 = boto3.client("s3") - - prefix = obj.path.lstrip("/").rstrip("/") + "/" - - # Delete all the files (including the index file in overwrite mode) - bucket_name = obj.netloc - s3 = boto3.resource("s3") - - checkpoint_prefix = os.path.join(prefix, ".checkpoints") - - for obj in s3.Bucket(bucket_name).objects.filter(Prefix=checkpoint_prefix): - s3.Object(bucket_name, obj.key).delete() + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError( + f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {self.output_dir.path}." + ) + with suppress(FileNotFoundError): + remove_file_or_directory( + os.path.join(self.output_dir.url, ".checkpoints"), storage_options=self.storage_options + ) def _save_current_config(self, workers_user_items: List[List[Any]]) -> None: if not self.use_checkpoint: @@ -1293,24 +1310,20 @@ def _save_current_config(self, workers_user_items: List[List[Any]]) -> None: obj = parse.urlparse(self.output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. 
Found {self.output_dir.path}.") - - # TODO: Add support for all cloud providers - - s3 = S3Client() - - prefix = obj.path.lstrip("/").rstrip("/") + "/" + ".checkpoints/" + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError( + f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {self.output_dir.path}." + ) # write config.json file to temp directory and upload it to s3 with tempfile.TemporaryDirectory() as temp_dir: temp_file_name = os.path.join(temp_dir, "config.json") with open(temp_file_name, "w") as f: json.dump(config, f) - s3.client.upload_file( + upload_file_or_directory( temp_file_name, - obj.netloc, - os.path.join(prefix, "config.json"), + os.path.join(self.output_dir.url, ".checkpoints", "config.json"), + storage_options=self.storage_options, ) except Exception as e: print(e) @@ -1361,26 +1374,25 @@ def _load_checkpoint_config(self, workers_user_items: List[List[Any]]) -> None: obj = parse.urlparse(self.output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. Found {self.output_dir.path}.") - - # TODO: Add support for all cloud providers - - prefix = obj.path.lstrip("/").rstrip("/") + "/" + ".checkpoints/" - - # Delete all the files (including the index file in overwrite mode) - bucket_name = obj.netloc + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError( + f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {self.output_dir.path}." + ) # download all the checkpoint files in tempdir and read them with tempfile.TemporaryDirectory() as temp_dir: - saved_file_dir = download_directory_from_S3(bucket_name, prefix, temp_dir) - - if not os.path.exists(os.path.join(saved_file_dir, "config.json")): + try: + download_file_or_directory( + os.path.join(self.output_dir.url, ".checkpoints/"), temp_dir, storage_options=self.storage_options + ) + except FileNotFoundError: + return + if not os.path.exists(os.path.join(temp_dir, "config.json")): # if the config.json file doesn't exist, we don't have any checkpoint saved return # read the config.json file - with open(os.path.join(saved_file_dir, "config.json")) as f: + with open(os.path.join(temp_dir, "config.json")) as f: config = json.load(f) if config["num_workers"] != self.num_workers: @@ -1394,11 +1406,11 @@ def _load_checkpoint_config(self, workers_user_items: List[List[Any]]) -> None: checkpoint_file_names = [f"checkpoint-{worker_idx}.json" for worker_idx in range(self.num_workers)] for i, checkpoint_file_name in enumerate(checkpoint_file_names): - if not os.path.exists(os.path.join(saved_file_dir, checkpoint_file_name)): + if not os.path.exists(os.path.join(temp_dir, checkpoint_file_name)): # if the checkpoint file doesn't exist, we don't have any checkpoint saved for this worker continue - with open(os.path.join(saved_file_dir, checkpoint_file_name)) as f: + with open(os.path.join(temp_dir, checkpoint_file_name)) as f: checkpoint = json.load(f) self.checkpoint_chunks_info[i] = checkpoint["chunks"] @@ -1407,6 +1419,7 @@ def _load_checkpoint_config(self, workers_user_items: List[List[Any]]) -> None: def in_notebook() -> bool: - """Returns ``True`` if the module is running in IPython kernel, ``False`` if in IPython shell or other Python - shell.""" + """Returns ``True`` if the module is running in IPython kernel, ``False`` if in IPython or other Python + shell. 
+ """ return "ipykernel" in sys.modules diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index 8edc2319..e83c8c1b 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -27,7 +27,7 @@ import torch -from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO +from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _SUPPORTED_CLOUD_PROVIDERS from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe from litdata.processing.readers import BaseReader from litdata.processing.utilities import ( @@ -36,8 +36,8 @@ optimize_dns_context, read_index_file_content, ) -from litdata.streaming.client import S3Client from litdata.streaming.dataloader import StreamingDataLoader +from litdata.streaming.downloader import copy_file_or_directory, upload_file_or_directory from litdata.streaming.item_loader import BaseItemLoader from litdata.streaming.resolver import ( Dir, @@ -53,7 +53,7 @@ def _is_remote_file(path: str) -> bool: obj = parse.urlparse(path) - return obj.scheme in ["s3", "gcs"] + return obj.scheme in _SUPPORTED_CLOUD_PROVIDERS def _get_indexed_paths(data: Any) -> Dict[int, str]: @@ -151,8 +151,15 @@ def __init__( compression: Optional[str], encryption: Optional[Encryption] = None, existing_index: Optional[Dict[str, Any]] = None, + storage_options: Optional[Dict] = {}, ): - super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression, encryption=encryption) + super().__init__( + chunk_size=chunk_size, + chunk_bytes=chunk_bytes, + compression=compression, + encryption=encryption, + storage_options=storage_options, + ) self._fn = fn self._inputs = inputs self.is_generator = False @@ -180,7 +187,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> Any: return self._inputs def prepare_item(self, item_metadata: Any) -> Any: - """This method is overridden dynamically.""" + """Overridden dynamically.""" def map( @@ -199,10 +206,11 @@ def map( error_when_not_empty: bool = False, reader: Optional[BaseReader] = None, batch_size: Optional[int] = None, + storage_options: Optional[Dict] = {}, ) -> None: - """This function maps a callable over a collection of inputs, possibly in a distributed way. + """Maps a callable over a collection of inputs, possibly in a distributed way. - Arguments: + Args: fn: A function to be executed over each input element inputs: A sequence of input to be processed by the `fn` function, or a streaming dataloader. output_dir: The folder where the processed data should be stored. @@ -218,7 +226,9 @@ def map( reorder_files: By default, reorders the files by file size to distribute work equally among all workers. Set this to ``False`` if the order in which samples are processed should be preserved. error_when_not_empty: Whether we should error if the output folder isn't empty. + reader: The reader to use when reading the data. By default, it uses the `BaseReader`. batch_size: Group the inputs into batches of batch_size length. + storage_options: The storage options used by the cloud provider. """ if isinstance(inputs, StreamingDataLoader) and batch_size is not None: @@ -253,11 +263,11 @@ def map( if _output_dir.url and "cloudspaces" in _output_dir.url: raise ValueError( f"The provided `output_dir` isn't valid. Found {_output_dir.path if _output_dir else None}." - " HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`." 
+ "\n HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`." ) if error_when_not_empty: - _assert_dir_is_empty(_output_dir) + _assert_dir_is_empty(_output_dir, storage_options=storage_options) if not isinstance(inputs, StreamingDataLoader): input_dir = input_dir or _get_input_dir(inputs) @@ -281,6 +291,7 @@ def map( reorder_files=reorder_files, weights=weights, reader=reader, + storage_options=storage_options, ) with optimize_dns_context(True): return data_processor.run(LambdaDataTransformRecipe(fn, inputs)) @@ -314,10 +325,11 @@ def optimize( use_checkpoint: bool = False, item_loader: Optional[BaseItemLoader] = None, start_method: Optional[str] = None, + storage_options: Optional[Dict] = {}, ) -> None: """This function converts a dataset into chunks, possibly in a distributed way. - Arguments: + Args: fn: A function to be executed over each input element. The function should return the data sample that corresponds to the input. Every invocation of the function should return a similar hierarchy of objects, where the object types and list sizes don't change. @@ -336,6 +348,7 @@ def optimize( machine: When doing remote execution, the machine to use. Only supported on https://lightning.ai/. num_downloaders: The number of downloaders per worker. num_uploaders: The numbers of uploaders per worker. + reader: The reader to use when reading the data. By default, it uses the `BaseReader`. reorder_files: By default, reorders the files by file size to distribute work equally among all workers. Set this to ``False`` if the order in which samples are processed should be preserved. batch_size: Group the inputs into batches of batch_size length. @@ -347,6 +360,7 @@ def optimize( the format in which the data is stored and optimized for loading. start_method: The start method used by python multiprocessing package. Default to spawn unless running inside an interactive shell like Ipython. + storage_options: The storage options used by the cloud provider. """ if mode is not None and mode not in ["append", "overwrite"]: @@ -398,10 +412,12 @@ def optimize( if _output_dir.url is not None and "cloudspaces" in _output_dir.url: raise ValueError( f"The provided `output_dir` isn't valid. Found {_output_dir.path}." - " HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`." + "\n HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`." 
) - _assert_dir_has_index_file(_output_dir, mode=mode, use_checkpoint=use_checkpoint) + _assert_dir_has_index_file( + _output_dir, mode=mode, use_checkpoint=use_checkpoint, storage_options=storage_options + ) if not isinstance(inputs, StreamingDataLoader): resolved_dir = _resolve_dir(input_dir or _get_input_dir(inputs)) @@ -417,7 +433,9 @@ def optimize( num_workers = num_workers or _get_default_num_workers() state_dict = {rank: 0 for rank in range(num_workers)} - existing_index_file_content = read_index_file_content(_output_dir) if mode == "append" else None + existing_index_file_content = ( + read_index_file_content(_output_dir, storage_options=storage_options) if mode == "append" else None + ) if existing_index_file_content is not None: for chunk in existing_index_file_content["chunks"]: @@ -439,6 +457,7 @@ def optimize( use_checkpoint=use_checkpoint, item_loader=item_loader, start_method=start_method, + storage_options=storage_options, ) with optimize_dns_context(True): @@ -451,6 +470,7 @@ def optimize( compression=compression, encryption=encryption, existing_index=existing_index_file_content, + storage_options=storage_options, ) ) return None @@ -468,7 +488,7 @@ def _listdir(folder: str) -> Tuple[str, List[str]]: class walk: """This class is an optimized version of os.walk for listing files and folders from cloud filesystem. - Note: The order of files and folders yielded aren't depth-first anymore due to the asynchronous listing call. + .. note:: The order of files and folders yielded aren't depth-first anymore due to the asynchronous listing call. """ @@ -519,13 +539,14 @@ class CopyInfo: new_filename: str -def merge_datasets(input_dirs: List[str], output_dir: str) -> None: +def merge_datasets(input_dirs: List[str], output_dir: str, storage_options: Optional[Dict] = {}) -> None: """The merge_datasets utility enables to merge multiple existing optimized datasets into a single optimized dataset. - Arguments: + Args: input_dirs: A list of directories pointing to the existing optimized datasets. output_dir: The directory where the merged dataset would be stored. + storage_options: A dictionary of storage options to be passed to the fsspec library. """ if len(input_dirs) == 0: @@ -540,12 +561,14 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None: if any(input_dir == resolved_output_dir for input_dir in resolved_input_dirs): raise ValueError("The provided output_dir was found within the input_dirs. 
This isn't supported.") - input_dirs_file_content = [read_index_file_content(input_dir) for input_dir in resolved_input_dirs] + input_dirs_file_content = [ + read_index_file_content(input_dir, storage_options=storage_options) for input_dir in resolved_input_dirs + ] if any(file_content is None for file_content in input_dirs_file_content): raise ValueError("One of the provided input_dir doesn't have an index file.") - output_dir_file_content = read_index_file_content(resolved_output_dir) + output_dir_file_content = read_index_file_content(resolved_output_dir, storage_options=storage_options) if output_dir_file_content is not None: raise ValueError("The output_dir already contains an optimized dataset") @@ -580,12 +603,12 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None: _tqdm = _get_tqdm_iterator_if_available() for copy_info in _tqdm(copy_infos): - _apply_copy(copy_info, resolved_output_dir) + _apply_copy(copy_info, resolved_output_dir, storage_options=storage_options) - _save_index(index_json, resolved_output_dir) + _save_index(index_json, resolved_output_dir, storage_options=storage_options) -def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None: +def _apply_copy(copy_info: CopyInfo, output_dir: Dir, storage_options: Optional[Dict] = {}) -> None: if output_dir.url is None and copy_info.input_dir.url is None: assert copy_info.input_dir.path assert output_dir.path @@ -595,20 +618,15 @@ def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None: shutil.copyfile(input_filepath, output_filepath) elif output_dir.url and copy_info.input_dir.url: - input_obj = parse.urlparse(os.path.join(copy_info.input_dir.url, copy_info.old_filename)) - output_obj = parse.urlparse(os.path.join(output_dir.url, copy_info.new_filename)) - - s3 = S3Client() - s3.client.copy( - {"Bucket": input_obj.netloc, "Key": input_obj.path.lstrip("/")}, - output_obj.netloc, - output_obj.path.lstrip("/"), - ) + input_obj = os.path.join(copy_info.input_dir.url, copy_info.old_filename) + output_obj = os.path.join(output_dir.url, copy_info.new_filename) + + copy_file_or_directory(input_obj, output_obj, storage_options=storage_options) else: raise NotImplementedError -def _save_index(index_json: Dict, output_dir: Dir) -> None: +def _save_index(index_json: Dict, output_dir: Dir, storage_options: Optional[Dict] = {}) -> None: if output_dir.url is None: assert output_dir.path with open(os.path.join(output_dir.path, _INDEX_FILENAME), "w") as f: @@ -619,11 +637,6 @@ def _save_index(index_json: Dict, output_dir: Dir) -> None: f.flush() - obj = parse.urlparse(os.path.join(output_dir.url, _INDEX_FILENAME)) - - s3 = S3Client() - s3.client.upload_file( - f.name, - obj.netloc, - obj.path.lstrip("/"), + upload_file_or_directory( + f.name, os.path.join(output_dir.url, _INDEX_FILENAME), storage_options=storage_options ) diff --git a/src/litdata/processing/readers.py b/src/litdata/processing/readers.py index b5d99cc6..beec5664 100644 --- a/src/litdata/processing/readers.py +++ b/src/litdata/processing/readers.py @@ -33,13 +33,11 @@ def get_node_rank(self) -> int: @abstractmethod def remap_items(self, items: Any, num_workers: int) -> List[Any]: - """This method is meant to remap the items provided by the users into items more adapted to be distributed.""" - pass + """Remap the items provided by the users into items more adapted to be distributed.""" @abstractmethod def read(self, item: Any) -> Any: """Read the data associated to an item.""" - pass class ParquetReader(BaseReader): diff --git 
a/src/litdata/processing/utilities.py b/src/litdata/processing/utilities.py index 87a5b3e7..a13e863d 100644 --- a/src/litdata/processing/utilities.py +++ b/src/litdata/processing/utilities.py @@ -21,11 +21,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from urllib import parse -import boto3 -import botocore - -from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO +from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _SUPPORTED_CLOUD_PROVIDERS from litdata.streaming.cache import Dir +from litdata.streaming.downloader import download_file_or_directory def _create_dataset( @@ -46,7 +44,7 @@ def _create_dataset( project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None) cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None) user_id = os.getenv("LIGHTNING_USER_ID", None) - cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None) + studio_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None) lightning_app_id = os.getenv("LIGHTNING_CLOUD_APP_ID", None) if project_id is None: @@ -64,7 +62,7 @@ def _create_dataset( try: client.dataset_service_create_dataset( body=ProjectIdDatasetsBody( - cloud_space_id=cloud_space_id if lightning_app_id is None else None, + cloud_space_id=studio_id if lightning_app_id is None else None, cluster_id=cluster_id, creator_id=user_id, empty=empty, @@ -183,7 +181,7 @@ def _get_work_dir() -> str: return f"s3://{bucket_name}/projects/{project_id}/lightningapps/{app_id}/artifacts/{work_id}/content/" -def read_index_file_content(output_dir: Dir) -> Optional[Dict[str, Any]]: +def read_index_file_content(output_dir: Dir, storage_options: Optional[Dict] = {}) -> Optional[Dict[str, Any]]: """Read the index file content.""" if not isinstance(output_dir, Dir): raise ValueError("The provided output_dir should be a Dir object.") @@ -201,27 +199,26 @@ def read_index_file_content(output_dir: Dir) -> Optional[Dict[str, Any]]: # download the index file from s3, and read it obj = parse.urlparse(output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. Found {output_dir.path}.") - - # TODO: Add support for all cloud providers - s3 = boto3.client("s3") - - prefix = obj.path.lstrip("/").rstrip("/") + "/" + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError( + f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {output_dir.path}." + ) # Check the index file exists try: # Create a temporary file with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file: temp_file_name = temp_file.name - s3.download_file(obj.netloc, os.path.join(prefix, _INDEX_FILENAME), temp_file_name) + download_file_or_directory( + os.path.join(output_dir.url, _INDEX_FILENAME), temp_file_name, storage_options=storage_options + ) # Read data from the temporary file with open(temp_file_name) as temp_file: data = json.load(temp_file) # Delete the temporary file os.remove(temp_file_name) return data - except botocore.exceptions.ClientError: + except Exception as _e: return None @@ -251,27 +248,8 @@ def remove_uuid_from_filename(filepath: str) -> str: -> `checkpoint-0.json` """ - if not filepath.__contains__(".checkpoints"): return filepath # uuid is of 32 characters, '.json' is 5 characters and '-' is 1 character return filepath[:-38] + ".json" - - -def download_directory_from_S3(bucket_name: str, remote_directory_name: str, local_directory_name: str) -> str: - s3_resource = boto3.resource("s3") - bucket = s3_resource.Bucket(bucket_name) - - saved_file_dir = "." 
- - for obj in bucket.objects.filter(Prefix=remote_directory_name): - local_filename = os.path.join(local_directory_name, obj.key) - - if not os.path.exists(os.path.dirname(local_filename)): - os.makedirs(os.path.dirname(local_filename)) - with open(local_filename, "wb") as f: - s3_resource.meta.client.download_fileobj(bucket_name, obj.key, f) - saved_file_dir = os.path.dirname(local_filename) - - return saved_file_dir diff --git a/src/litdata/streaming/cache.py b/src/litdata/streaming/cache.py index 816ef816..d045d5da 100644 --- a/src/litdata/streaming/cache.py +++ b/src/litdata/streaming/cache.py @@ -51,11 +51,12 @@ def __init__( """The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements together in order to accelerate fetching. - Arguments: + Args: input_dir: The path to where the chunks will be or are stored. subsampled_files: List of subsampled chunk files loaded from `input_dir/index.json` file. region_of_interest: List of tuples of (start,end) of region of interest for each chunk. compression: The name of the algorithm to reduce the size of the chunks. + encryption: The encryption algorithm to use. chunk_bytes: The maximum number of bytes within a chunk. chunk_size: The maximum number of items within a chunk. item_loader: The object responsible to generate the chunk intervals and load an item froma chunk. diff --git a/src/litdata/streaming/client.py b/src/litdata/streaming/client.py deleted file mode 100644 index d24803c3..00000000 --- a/src/litdata/streaming/client.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright The Lightning AI team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from time import time -from typing import Any, Dict, Optional - -import boto3 -import botocore -from botocore.credentials import InstanceMetadataProvider -from botocore.utils import InstanceMetadataFetcher - -from litdata.constants import _IS_IN_STUDIO - - -class S3Client: - # TODO: Generalize to support more cloud providers. 
- - def __init__(self, refetch_interval: int = 3300, storage_options: Optional[Dict] = {}) -> None: - self._refetch_interval = refetch_interval - self._last_time: Optional[float] = None - self._client: Optional[Any] = None - self._storage_options: dict = storage_options or {} - - def _create_client(self) -> None: - has_shared_credentials_file = ( - os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials" - ) - - if has_shared_credentials_file or not _IS_IN_STUDIO: - self._client = boto3.client( - "s3", - **{ - "config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}), - **self._storage_options, - }, - ) - else: - provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5)) - credentials = provider.load() - self._client = boto3.client( - "s3", - aws_access_key_id=credentials.access_key, - aws_secret_access_key=credentials.secret_key, - aws_session_token=credentials.token, - config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}), - ) - - @property - def client(self) -> Any: - if self._client is None: - self._create_client() - self._last_time = time() - - # Re-generate credentials for EC2 - if self._last_time is None or (time() - self._last_time) > self._refetch_interval: - self._create_client() - self._last_time = time() - - return self._client diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py index f3fd0e9b..22149589 100644 --- a/src/litdata/streaming/combined.py +++ b/src/litdata/streaming/combined.py @@ -25,7 +25,7 @@ class CombinedStreamingDataset(IterableDataset): - """The `CombinedStreamingDataset` enables to stream data from multiple StreamingDataset with the sampling ratio of + """Enables streaming data from multiple StreamingDataset with the sampling ratio of your choice. Additionally, the `CombinedStreamingDataset` keeps track of the number of samples fetched to enable reusability @@ -43,15 +43,16 @@ def __init__( self, weights: Optional[Sequence[float]] = None, iterate_over_all: bool = True, ) -> None: - """ " - Arguments: + """Enable streaming data from multiple StreamingDataset with the sampling ratio of your choice. + + Args: datasets: The list of the StreamingDataset to use. seed: The random seed to initialize the sampler weights: The sampling ratio for the datasets iterate_over_all: When iterate_over_all is True, the combined dataset iterates over all the datasets. Otherwise, it stops as soon as one raises a StopIteration. - """ + """ self._check_datasets(datasets) self._seed = seed diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index 997d2a83..df0ea012 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -35,14 +35,14 @@ def __init__( region_of_interest: Optional[List[Tuple[int, int]]] = None, storage_options: Optional[Dict] = {}, ) -> None: - """The ChunksConfig reads the index files associated a chunked dataset and enables to map an index to its - chunk. + """Reads the index files associated with a chunked dataset and enables mapping an index to its chunk. Arguments: cache_dir: The path to cache folder. serializers: The serializers used to serialize and deserialize the chunks. remote_dir: The path to a remote folder where the data are located. The scheme needs to be added to the path. + item_loader: The item loader used to load the data from the chunks. 
subsampled_files: List of subsampled chunk files loaded from `input_dir/index.json` file. region_of_interest: List of tuples of {start,end} of region of interest for each chunk. storage_options: Additional connection options for accessing storage services. diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 3f932fae..1612f55d 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -80,11 +80,11 @@ def __init__( ): """The `CacheDataset` is a dataset wrapper to provide a beginner experience with the Cache. - Arguments: + Args: dataset: The dataset of the user cache_dir: The folder where the chunks are written to. chunk_bytes: The maximal number of bytes to write within a chunk. - chunk_sie: The maximal number of items to write to a chunk. + chunk_size: The maximal number of items to write to a chunk. compression: The compression algorithm to use to reduce the size of the chunk. """ @@ -103,7 +103,7 @@ def __getitem__(self, index: int) -> Any: if not _equal_items(data_1, data2): raise ValueError( f"Your dataset items aren't deterministic. Found {data_1} and {data2} for index {index}." - " HINT: Use the `litdata.cache.Cache` directly within your dataset." + "\n HINT: Use the `litdata.cache.Cache` directly within your dataset." ) self._is_deterministic = True self._cache[index] = data_1 @@ -115,7 +115,7 @@ class CacheCollateFn: During the chunking phase, there is no need to return any data from the DataLoader reducing some time. - Additionally, if the user makes their __getitem__ asynchronous, the collate executes them in parallel. + Additionally, if the user makes their __getitem__ asynchronous, collate executes them in parallel. """ diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 5c57cd69..a6f70bf5 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -59,7 +59,7 @@ def __init__( ) -> None: """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class. - Arguments: + Args: input_dir: Path to the folder where the input data is stored. item_loader: The logic to load an item from a chunk. shuffle: Whether to shuffle the data. @@ -155,7 +155,8 @@ def set_epoch(self, current_epoch: int) -> None: def _create_cache(self, worker_env: _WorkerEnv) -> Cache: if _should_replace_path(self.input_dir.path): cache_path = _try_create_cache_dir( - input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url + input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url, + storage_options=self.storage_options, ) if cache_path is not None: self.input_dir.path = cache_path @@ -177,7 +178,7 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache: if not cache.filled: raise ValueError( f"The provided dataset `{self.input_dir}` doesn't contain any {_INDEX_FILENAME} file." - " HINT: Did you successfully optimize a dataset to the provided `input_dir`?" + "\n HINT: Did you successfully optimize a dataset to the provided `input_dir`?" ) return cache @@ -438,7 +439,8 @@ def _validate_state_dict(self) -> None: # In this case, validate the cache folder is the same. 
if _should_replace_path(state["input_dir_path"]): cache_path = _try_create_cache_dir( - input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"] + input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"], + storage_options=self.storage_options, ) if cache_path != self.input_dir.path: raise ValueError( diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index fa4e5fe3..463ab576 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -16,23 +16,32 @@ import shutil import subprocess from abc import ABC -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from urllib import parse +import fsspec from filelock import FileLock, Timeout -from litdata.constants import _AZURE_STORAGE_AVAILABLE, _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME -from litdata.streaming.client import S3Client +from litdata.constants import _INDEX_FILENAME + +# from litdata.streaming.client import S3Client + +_USE_S5CMD_FOR_S3 = True class Downloader(ABC): def __init__( - self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} + self, + cloud_provider: str, + remote_dir: str, + cache_dir: str, + chunks: List[Dict[str, Any]], + storage_options: Optional[Dict] = {}, ): self._remote_dir = remote_dir self._cache_dir = cache_dir self._chunks = chunks - self._storage_options = storage_options or {} + self.fs = fsspec.filesystem(cloud_provider, **storage_options) def download_chunk_from_index(self, chunk_index: int) -> None: chunk_filename = self._chunks[chunk_index]["filename"] @@ -44,164 +53,195 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None: pass -class S3Downloader(Downloader): - def __init__( - self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} - ): - super().__init__(remote_dir, cache_dir, chunks, storage_options) - self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 - - if not self._s5cmd_available: - self._client = S3Client(storage_options=self._storage_options) - +class LocalDownloader(Downloader): def download_file(self, remote_filepath: str, local_filepath: str) -> None: - obj = parse.urlparse(remote_filepath) - - if obj.scheme != "s3": - raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}") - - if os.path.exists(local_filepath): - return + if not os.path.exists(remote_filepath): + raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}") try: - with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): - if self._s5cmd_available: - proc = subprocess.Popen( - f"s5cmd cp {remote_filepath} {local_filepath}", - shell=True, - stdout=subprocess.PIPE, - ) - proc.wait() - else: - from boto3.s3.transfer import TransferConfig - - extra_args: Dict[str, Any] = {} - - # try: - # with FileLock(local_filepath + ".lock", timeout=1): - if not os.path.exists(local_filepath): - # Issue: https://github.com/boto/boto3/issues/3113 - self._client.client.download_file( - obj.netloc, - obj.path.lstrip("/"), - local_filepath, - ExtraArgs=extra_args, - Config=TransferConfig(use_threads=False), - ) + with FileLock(local_filepath + ".lock", timeout=3 if remote_filepath.endswith(_INDEX_FILENAME) else 0): + if remote_filepath != local_filepath and not os.path.exists(local_filepath): + # make an atomic 
operation to be safe + temp_file_path = local_filepath + ".tmp" + shutil.copy(remote_filepath, temp_file_path) + os.rename(temp_file_path, local_filepath) + with contextlib.suppress(Exception): + os.remove(local_filepath + ".lock") except Timeout: - # another process is responsible to download that file, continue pass -class GCPDownloader(Downloader): - def __init__( - self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} - ): - if not _GOOGLE_STORAGE_AVAILABLE: - raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE)) +class LocalDownloaderWithCache(LocalDownloader): + def download_file(self, remote_filepath: str, local_filepath: str) -> None: + remote_filepath = remote_filepath.replace("local:", "") + super().download_file(remote_filepath, local_filepath) - super().__init__(remote_dir, cache_dir, chunks, storage_options) - def download_file(self, remote_filepath: str, local_filepath: str) -> None: - from google.cloud import storage +def download_s3_file_via_s5cmd(remote_filepath: str, local_filepath: str) -> None: + _s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 - obj = parse.urlparse(remote_filepath) + if _s5cmd_available is False: + raise ModuleNotFoundError(str(_s5cmd_available)) - if obj.scheme != "gs": - raise ValueError(f"Expected obj.scheme to be `gs`, instead, got {obj.scheme} for remote={remote_filepath}") + obj = parse.urlparse(remote_filepath) - if os.path.exists(local_filepath): - return + if obj.scheme != "s3": + raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for {remote_filepath}") - try: - with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): - bucket_name = obj.netloc - key = obj.path - # Remove the leading "/": - if key[0] == "/": - key = key[1:] - - client = storage.Client(**self._storage_options) - bucket = client.bucket(bucket_name) - blob = bucket.blob(key) - blob.download_to_filename(local_filepath) - except Timeout: - # another process is responsible to download that file, continue - pass + if os.path.exists(local_filepath): + return + try: + with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): + proc = subprocess.Popen( + f"s5cmd cp {remote_filepath} {local_filepath}", + shell=True, + stdout=subprocess.PIPE, + ) + proc.wait() + except Timeout: + # another process is responsible to download that file, continue + pass -class AzureDownloader(Downloader): - def __init__( - self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} - ): - if not _AZURE_STORAGE_AVAILABLE: - raise ModuleNotFoundError(str(_AZURE_STORAGE_AVAILABLE)) - super().__init__(remote_dir, cache_dir, chunks, storage_options) +_DOWNLOADERS = { + "s3://": "s3", + "gs://": "gs", + "azure://": "abfs", + "abfs://": "abfs", + "local:": "file", + "": "file", +} + +_DEFAULT_STORAGE_OPTIONS = { + "s3": {"config_kwargs": {"retries": {"max_attempts": 1000, "mode": "adaptive"}}}, +} - def download_file(self, remote_filepath: str, local_filepath: str) -> None: - from azure.storage.blob import BlobServiceClient - obj = parse.urlparse(remote_filepath) +def get_complete_storage_options(cloud_provider: str, storage_options: Optional[Dict] = {}) -> Dict: + if storage_options is None: + storage_options = {} + if cloud_provider in _DEFAULT_STORAGE_OPTIONS: + return {**_DEFAULT_STORAGE_OPTIONS[cloud_provider], **storage_options} + return storage_options - if obj.scheme != "azure": - raise 
ValueError( - f"Expected obj.scheme to be `azure`, instead, got {obj.scheme} for remote={remote_filepath}" - ) - if os.path.exists(local_filepath): - return +class FsspecDownloader(Downloader): + def __init__( + self, + cloud_provider: str, + remote_dir: str, + cache_dir: str, + chunks: List[Dict[str, Any]], + storage_options: Optional[Dict] = {}, + ): + remote_dir = remote_dir.replace("local:", "") + self.is_local = False + storage_options = get_complete_storage_options(cloud_provider, storage_options) + super().__init__(cloud_provider, remote_dir, cache_dir, chunks, storage_options) + self.cloud_provider = cloud_provider + self.use_s5cmd = cloud_provider == "s3" and os.system("s5cmd > /dev/null 2>&1") == 0 + def download_file(self, remote_filepath: str, local_filepath: str) -> None: + if os.path.exists(local_filepath) or remote_filepath == local_filepath: + return + if self.use_s5cmd and _USE_S5CMD_FOR_S3: + download_s3_file_via_s5cmd(remote_filepath, local_filepath) + return try: - with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): - service = BlobServiceClient(**self._storage_options) - blob_client = service.get_blob_client(container=obj.netloc, blob=obj.path.lstrip("/")) - with open(local_filepath, "wb") as download_file: - blob_data = blob_client.download_blob() - blob_data.readinto(download_file) - + with FileLock(local_filepath + ".lock", timeout=3): + self.fs.get(remote_filepath, local_filepath, recursive=True) + # remove the lock file + if os.path.exists(local_filepath + ".lock"): + os.remove(local_filepath + ".lock") except Timeout: # another process is responsible to download that file, continue pass -class LocalDownloader(Downloader): - def download_file(self, remote_filepath: str, local_filepath: str) -> None: - if not os.path.exists(remote_filepath): - raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}") +def does_file_exist( + remote_filepath: str, cloud_provider: Union[str, None] = None, storage_options: Optional[Dict] = {} +) -> bool: + if cloud_provider is None: + cloud_provider = get_cloud_provider(remote_filepath) + storage_options = get_complete_storage_options(cloud_provider, storage_options) + fs = fsspec.filesystem(cloud_provider, **storage_options) + return fs.exists(remote_filepath) + + +def list_directory( + remote_directory: str, + detail: bool = False, + cloud_provider: Optional[str] = None, + storage_options: Optional[Dict] = {}, +) -> List[str]: + """Returns a list of filenames in a remote directory.""" + if cloud_provider is None: + cloud_provider = get_cloud_provider(remote_directory) + storage_options = get_complete_storage_options(cloud_provider, storage_options) + fs = fsspec.filesystem(cloud_provider, **storage_options) + return fs.ls(remote_directory, detail=detail) # just return the filenames + + +def download_file_or_directory(remote_filepath: str, local_filepath: str, storage_options: Optional[Dict] = {}) -> None: + """Download a file from the remote cloud storage.""" + fs_cloud_provider = get_cloud_provider(remote_filepath) + use_s5cmd = fs_cloud_provider == "s3" and os.system("s5cmd > /dev/null 2>&1") == 0 + if use_s5cmd and _USE_S5CMD_FOR_S3: + download_s3_file_via_s5cmd(remote_filepath, local_filepath) + return + try: + with FileLock(local_filepath + ".lock", timeout=3): + storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) + fs = fsspec.filesystem(fs_cloud_provider, **storage_options) + fs.get(remote_filepath, local_filepath, 
recursive=True) + except Timeout: + # another process is responsible to download that file, continue + pass - try: - with FileLock(local_filepath + ".lock", timeout=3 if remote_filepath.endswith(_INDEX_FILENAME) else 0): - if remote_filepath != local_filepath and not os.path.exists(local_filepath): - # make an atomic operation to be safe - temp_file_path = local_filepath + ".tmp" - shutil.copy(remote_filepath, temp_file_path) - os.rename(temp_file_path, local_filepath) - with contextlib.suppress(Exception): - os.remove(local_filepath + ".lock") - except Timeout: - pass +def upload_file_or_directory(local_filepath: str, remote_filepath: str, storage_options: Optional[Dict] = {}) -> None: + """Upload a file to the remote cloud storage.""" + try: + with FileLock(local_filepath + ".lock", timeout=3): + fs_cloud_provider = get_cloud_provider(remote_filepath) + storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) + fs = fsspec.filesystem(fs_cloud_provider, **storage_options) + fs.put(local_filepath, remote_filepath, recursive=True) + except Timeout: + # another process is responsible to upload that file, continue + pass -class LocalDownloaderWithCache(LocalDownloader): - def download_file(self, remote_filepath: str, local_filepath: str) -> None: - remote_filepath = remote_filepath.replace("local:", "") - super().download_file(remote_filepath, local_filepath) +def copy_file_or_directory( + remote_filepath_src: str, remote_filepath_tg: str, storage_options: Optional[Dict] = {} +) -> None: + """Copy a file from src to target on the remote cloud storage.""" + fs_cloud_provider = get_cloud_provider(remote_filepath_src) + storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) + fs = fsspec.filesystem(fs_cloud_provider, **storage_options) + fs.copy(remote_filepath_src, remote_filepath_tg, recursive=True) -_DOWNLOADERS = { - "s3://": S3Downloader, - "gs://": GCPDownloader, - "azure://": AzureDownloader, - "local:": LocalDownloaderWithCache, - "": LocalDownloader, -} + +def remove_file_or_directory(remote_filepath: str, storage_options: Optional[Dict] = {}) -> None: + """Remove a file from the remote cloud storage.""" + fs_cloud_provider = get_cloud_provider(remote_filepath) + storage_options = get_complete_storage_options(fs_cloud_provider, storage_options) + fs = fsspec.filesystem(fs_cloud_provider, **storage_options) + fs.rm(remote_filepath, recursive=True) + + +def get_cloud_provider(remote_filepath: str) -> str: + for k, fs_cloud_provider in _DOWNLOADERS.items(): + if str(remote_filepath).startswith(k): + return fs_cloud_provider + raise ValueError(f"The provided `remote_filepath` {remote_filepath} doesn't have a downloader associated.") def get_downloader_cls( remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} ) -> Downloader: - for k, cls in _DOWNLOADERS.items(): + for k, fs_cloud_provider in _DOWNLOADERS.items(): if str(remote_dir).startswith(k): - return cls(remote_dir, cache_dir, chunks, storage_options) + return FsspecDownloader(fs_cloud_provider, remote_dir, cache_dir, chunks, storage_options) raise ValueError(f"The provided `remote_dir` {remote_dir} doesn't have a downloader associated.") diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 877716cb..acc50d2b 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -70,18 +70,16 @@ def state_dict(self) -> Dict: @abstractmethod def generate_intervals(self) 
-> List[Interval]: - """Returns a list of intervals: [chunk_start, - region_of_interest_start, region_of_interest_end, chunk_end] + """Returns a list of intervals. - region_of_interest: indicates the indexes a chunk our StreamingDataset is allowed to read. + The structure is: [chunk_start, region_of_interest_start, region_of_interest_end, chunk_end] + region_of_interest: indicates the indexes a chunk our StreamingDataset is allowed to read. """ - pass @abstractmethod def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None: """Logic to load the chunk in background to gain some time.""" - pass @abstractmethod def load_item_from_chunk( @@ -93,12 +91,10 @@ def load_item_from_chunk( chunk_bytes: int, ) -> Any: """Returns an item loaded from a chunk.""" - pass @abstractmethod def delete(self, chunk_index: int, chunk_filepath: str) -> None: """Delete a chunk from the local filesystem.""" - pass @abstractmethod def encode_data(self, data: List[bytes], sizes: List[int], flattened: List[Any]) -> Any: @@ -168,7 +164,6 @@ def _load_encrypted_data( self, chunk_filepath: str, chunk_index: int, offset: int, encryption: Optional[Encryption] ) -> bytes: """Load and decrypt data from chunk based on the encryption configuration.""" - # Validate the provided encryption object against the expected configuration. self._validate_encryption(encryption) @@ -261,11 +256,10 @@ class TokensLoader(BaseItemLoader): def __init__(self, block_size: Optional[int] = None): """The Tokens Loader is an optimizer item loader for NLP. - Arguments: + Args: block_size: The context length to use during training. """ - super().__init__() self._block_size = block_size self._mmaps: Dict[int, np.memmap] = {} diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index bc02380a..60cd2965 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -19,7 +19,6 @@ from threading import Event, Thread from typing import Any, Dict, List, Optional, Tuple, Union -from litdata.constants import _TORCH_GREATER_EQUAL_2_1_0 from litdata.streaming.config import ChunksConfig, Interval from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader from litdata.streaming.sampler import ChunkedIndex @@ -29,9 +28,6 @@ warnings.filterwarnings("ignore", message=".*The given buffer is not writable.*") -if _TORCH_GREATER_EQUAL_2_1_0: - pass - logger = Logger(__name__) @@ -174,7 +170,7 @@ def __init__( ) -> None: """The BinaryReader enables to read chunked dataset in an efficient way. - Arguments: + Args: cache_dir: The path to cache folder. subsampled_files: List of subsampled chunk files loaded from `input_dir/index.json` file. region_of_interest: List of tuples of {start,end} of region of interest for each chunk. 
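Note on the downloader refactor above: the provider-specific `S3Downloader`, `GCPDownloader`, and `AzureDownloader` classes are replaced by a single `FsspecDownloader` whose fsspec protocol is picked from the URL prefix, with an s5cmd fast path kept for S3. The following is only a minimal sketch of how the new helpers are expected to fit together; the bucket, credentials, and cache directory are placeholders, not values from this patch.

```python
# Minimal sketch (placeholder bucket/credentials) of the fsspec-based helpers
# added in src/litdata/streaming/downloader.py above.
import fsspec

from litdata.streaming.downloader import (
    get_cloud_provider,
    get_complete_storage_options,
    get_downloader_cls,
)

remote_dir = "s3://my-bucket/my-optimized-dataset"   # placeholder
storage_options = {"key": "...", "secret": "..."}    # forwarded to fsspec/s3fs

# The URL prefix (s3://, gs://, azure://, abfs://, local:, "") maps to an
# fsspec protocol through the `_DOWNLOADERS` table.
provider = get_cloud_provider(remote_dir)            # -> "s3"

# For S3, adaptive-retry defaults are merged in before the filesystem is built.
options = get_complete_storage_options(provider, storage_options)
fs = fsspec.filesystem(provider, **options)

# `get_downloader_cls` now always returns an `FsspecDownloader`, which wraps
# the same filesystem and keeps the s5cmd fast path for S3 when available.
downloader = get_downloader_cls(remote_dir, "/tmp/cache", [], storage_options)
downloader.download_file(f"{remote_dir}/index.json", "/tmp/cache/index.json")
```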
diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index 874497a6..a1781ccc 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -20,13 +20,15 @@ from dataclasses import dataclass from pathlib import Path from time import sleep -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union from urllib import parse -import boto3 -import botocore - -from litdata.constants import _LIGHTNING_SDK_AVAILABLE +from litdata.constants import _LIGHTNING_SDK_AVAILABLE, _SUPPORTED_CLOUD_PROVIDERS +from litdata.streaming.downloader import ( + does_file_exist, + list_directory, + remove_file_or_directory, +) if TYPE_CHECKING: from lightning_sdk import Machine @@ -52,9 +54,15 @@ def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir: assert isinstance(dir_path, str) - cloud_prefixes = ("s3://", "gs://", "azure://") - if dir_path.startswith(cloud_prefixes): - return Dir(path=None, url=dir_path) + cloud_prefixes = _SUPPORTED_CLOUD_PROVIDERS + dir_scheme = parse.urlparse(dir_path).scheme + if bool(dir_scheme) and dir_scheme not in ["c", "d", "e", "f"]: # prevent windows `c:\\` and `d:\\` + if any(dir_path.startswith(cloud_prefix) for cloud_prefix in cloud_prefixes): + return Dir(path=None, url=dir_path) + raise ValueError( + f"The provided dir_path `{dir_path}` is not supported.", + f" HINT: Only the following cloud providers are supported: {_SUPPORTED_CLOUD_PROVIDERS}.", + ) if dir_path.startswith("local:"): return Dir(path=None, url=dir_path) @@ -88,14 +96,11 @@ def _match_studio(target_id: Optional[str], target_name: Optional[str], cloudspa if target_id is not None and cloudspace.id == target_id: return True - if ( + return bool( cloudspace.display_name is not None and target_name is not None and cloudspace.display_name.lower() == target_name.lower() - ): - return True - - return False + ) def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Optional[str]) -> Dir: @@ -108,10 +113,10 @@ def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Option project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None) if cluster_id is None: - raise RuntimeError("The `cluster_id` couldn't be found from the environment variables.") + raise RuntimeError("The `LIGHTNING_CLUSTER_ID` couldn't be found from the environment variables.") if project_id is None: - raise RuntimeError("The `project_id` couldn't be found from the environment variables.") + raise RuntimeError("The `LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables.") clusters = client.cluster_service_list_project_clusters(project_id).clusters @@ -147,7 +152,7 @@ def _resolve_s3_connections(dir_path: str) -> Dir: # Get the ids from env variables project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None) if project_id is None: - raise RuntimeError("The `project_id` couldn't be found from the environment variables.") + raise RuntimeError("The `LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables.") target_name = dir_path.split("/")[3] @@ -169,16 +174,16 @@ def _resolve_datasets(dir_path: str) -> Dir: # Get the ids from env variables cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None) project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None) - cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None) + studio_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None) if cluster_id is None: - raise RuntimeError("The `cluster_id` couldn't be 
found from the environment variables.") + raise RuntimeError("The `LIGHTNING_CLUSTER_ID` couldn't be found from the environment variables.") if project_id is None: - raise RuntimeError("The `project_id` couldn't be found from the environment variables.") + raise RuntimeError("The `LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables.") - if cloud_space_id is None: - raise RuntimeError("The `cloud_space_id` couldn't be found from the environment variables.") + if studio_id is None: + raise RuntimeError("The `LIGHTNING_CLOUD_SPACE_ID` couldn't be found from the environment variables.") clusters = client.cluster_service_list_project_clusters(project_id).clusters @@ -187,17 +192,17 @@ def _resolve_datasets(dir_path: str) -> Dir: for cloudspace in client.cloud_space_service_list_cloud_spaces( project_id=project_id, cluster_id=cluster_id ).cloudspaces - if cloudspace.id == cloud_space_id + if cloudspace.id == studio_id ] if not target_cloud_space: - raise ValueError(f"We didn't find any matching Studio for the provided id `{cloud_space_id}`.") + raise ValueError(f"We didn't find any matching Studio for the provided id `{studio_id}`.") target_cluster = [cluster for cluster in clusters if cluster.id == target_cloud_space[0].cluster_id] if not target_cluster: raise ValueError( - f"We didn't find a matching cluster associated with the id {target_cloud_space[0].cluster_id}." + f"We didn't find a matching cluster associated with the id `{target_cloud_space[0].cluster_id}`." ) return Dir( @@ -209,37 +214,38 @@ def _resolve_datasets(dir_path: str) -> Dir: ) -def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool = False) -> None: +def _assert_dir_is_empty( + output_dir: Dir, append: bool = False, overwrite: bool = False, storage_options: Optional[Dict] = {} +) -> None: if not isinstance(output_dir, Dir): - raise ValueError("The provided output_dir isn't a Dir Object.") + raise ValueError("The provided output_dir isn't a `Dir` Object.") if output_dir.url is None: return obj = parse.urlparse(output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. Found {output_dir.url}.") + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError(f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {output_dir.url}.") - s3 = boto3.client("s3") - - objects = s3.list_objects_v2( - Bucket=obj.netloc, - Delimiter="/", - Prefix=obj.path.lstrip("/").rstrip("/") + "/", - ) + try: + object_list = list_directory(output_dir.url, storage_options=storage_options) + except FileNotFoundError: + return # We aren't alloweing to add more data - # TODO: Add support for `append` and `overwrite`. - if objects["KeyCount"] > 0: + if object_list is not None and len(object_list) > 0: raise RuntimeError( f"The provided output_dir `{output_dir.path}` already contains data and datasets are meant to be immutable." - " HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?" + "\n HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?" 
) def _assert_dir_has_index_file( - output_dir: Dir, mode: Optional[Literal["append", "overwrite"]] = None, use_checkpoint: bool = False + output_dir: Dir, + mode: Optional[Literal["append", "overwrite"]] = None, + use_checkpoint: bool = False, + storage_options: Optional[Dict] = {}, ) -> None: if mode is not None and mode not in ["append", "overwrite"]: raise ValueError(f"The provided `mode` should be either `append` or `overwrite`. Found {mode}.") @@ -261,8 +267,8 @@ def _assert_dir_has_index_file( if os.path.exists(index_file) and mode is None: raise RuntimeError( f"The provided output_dir `{output_dir.path}` already contains an optimized immutable datasets." - " HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?" - " HINT: If you want to append/overwrite to the existing dataset, use `mode='append | overwrite'`." + "\n HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?" + "\n HINT: If you want to append/overwrite to the existing dataset, use `mode='append | overwrite'`." ) # delete index.json file and chunks @@ -283,44 +289,29 @@ def _assert_dir_has_index_file( obj = parse.urlparse(output_dir.url) - if obj.scheme != "s3": - raise ValueError(f"The provided folder should start with s3://. Found {output_dir.url}.") - - s3 = boto3.client("s3") + if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS: + raise ValueError(f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {output_dir.url}.") - prefix = obj.path.lstrip("/").rstrip("/") + "/" - - objects = s3.list_objects_v2( - Bucket=obj.netloc, - Delimiter="/", - Prefix=prefix, - ) + objects_list = [] + with suppress(FileNotFoundError): + objects_list = list_directory(output_dir.url, storage_options=storage_options) # No files are found in this folder - if objects["KeyCount"] == 0: + if objects_list is None or len(objects_list) == 0: return # Check the index file exists - try: - s3.head_object(Bucket=obj.netloc, Key=os.path.join(prefix, "index.json")) - has_index_file = True - except botocore.exceptions.ClientError: - has_index_file = False + has_index_file = does_file_exist(os.path.join(output_dir.url, "index.json"), storage_options=storage_options) if has_index_file and mode is None: raise RuntimeError( f"The provided output_dir `{output_dir.path}` already contains an optimized immutable datasets." - " HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?" - " HINT: If you want to append/overwrite to the existing dataset, use `mode='append | overwrite'`." + "\n HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?" + "\n HINT: If you want to append/overwrite to the existing dataset, use `mode='append | overwrite'`." 
) - # Delete all the files (including the index file in overwrite mode) - bucket_name = obj.netloc - s3 = boto3.resource("s3") - if mode == "overwrite" or (mode is None and not use_checkpoint): - for obj in s3.Bucket(bucket_name).objects.filter(Prefix=prefix): - s3.Object(bucket_name, obj.key).delete() + remove_file_or_directory(output_dir.url, storage_options=storage_options) def _get_lightning_cloud_url() -> str: @@ -348,7 +339,6 @@ def _execute( command: Optional[str] = None, ) -> None: """Remotely execute the current operator.""" - if not _LIGHTNING_SDK_AVAILABLE: raise ModuleNotFoundError("The `lightning_sdk` is required.") diff --git a/src/litdata/streaming/sampler.py b/src/litdata/streaming/sampler.py index f25337d6..d135979b 100644 --- a/src/litdata/streaming/sampler.py +++ b/src/litdata/streaming/sampler.py @@ -45,7 +45,7 @@ def __init__( If the cache isn't filled, the batch sampler alternates with ordered indices for the writer to chunk the dataset If the cache is filled, it acts as normal BatchSampler. - Arguments: + Args: dataset_size: The size of the dataset. num_replicas: The number of processes involves in the distributed training. global_rank: The global_rank of the given process diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index 9d5d511f..f453f538 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -17,6 +17,7 @@ import tempfile from abc import ABC, abstractmethod from collections import OrderedDict +from contextlib import suppress from copy import deepcopy from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union @@ -28,7 +29,6 @@ if TYPE_CHECKING: from PIL.JpegImagePlugin import JpegImageFile - _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") _AV_AVAILABLE = RequirementCache("av") @@ -69,6 +69,8 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: @classmethod def deserialize(cls, data: bytes) -> Any: + if not _PIL_AVAILABLE: + raise ModuleNotFoundError("PIL is required. Run `pip install pillow`") from PIL import Image idx = 3 * 4 @@ -93,6 +95,9 @@ class JPEGSerializer(Serializer): """The JPEGSerializer serialize and deserialize JPEG image to and from bytes.""" def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: + if not _PIL_AVAILABLE: + raise ModuleNotFoundError("PIL is required. Run `pip install pillow`") + from PIL import Image from PIL.GifImagePlugin import GifImageFile from PIL.JpegImagePlugin import JpegImageFile @@ -102,7 +107,8 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: if isinstance(item, JpegImageFile): if not hasattr(item, "filename"): raise ValueError( - "The JPEG Image's filename isn't defined. HINT: Open the image in your Dataset __getitem__ method." + "The JPEG Image's filename isn't defined." + "\n HINT: Open the image in your Dataset `__getitem__` method." ) if item.filename and os.path.isfile(item.filename): # read the content of the file directly @@ -120,7 +126,7 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: buff.seek(0) return buff.read(), None - raise TypeError(f"The provided item should be of type {JpegImageFile}. Found {item}.") + raise TypeError(f"The provided item should be of type `JpegImageFile`. 
Found {item}.") def deserialize(self, data: bytes) -> Union["JpegImageFile", torch.Tensor]: if _TORCH_VISION_AVAILABLE: @@ -128,11 +134,9 @@ def deserialize(self, data: bytes) -> Union["JpegImageFile", torch.Tensor]: from torchvision.transforms.functional import pil_to_tensor array = torch.frombuffer(data, dtype=torch.uint8) - try: + # Note: Some datasets like Imagenet contains some PNG images with JPEG extension, so we fallback to PIL + with suppress(RuntimeError): return decode_jpeg(array) - except RuntimeError: - # Note: Some datasets like Imagenet contains some PNG images with JPEG extension, so we fallback to PIL - pass img = PILSerializer.deserialize(data) if _TORCH_VISION_AVAILABLE: @@ -184,7 +188,7 @@ def deserialize(self, data: bytes) -> torch.Tensor: shape = [] for shape_idx in range(shape_size): shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item()) - idx_start = 8 + 4 * (shape_idx + 1) + idx_start = 8 + 4 * shape_size idx_end = len(data) if idx_end > idx_start: tensor = torch.frombuffer(data[idx_start:idx_end], dtype=dtype) @@ -250,7 +254,7 @@ def deserialize(self, data: bytes) -> np.ndarray: shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item()) # deserialize the numpy array bytes - tensor = np.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype) + tensor = np.frombuffer(data[8 + 4 * shape_size : len(data)], dtype=dtype) if tensor.shape == shape: return tensor return np.reshape(tensor, shape) diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index 752b21d0..90e1085b 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -59,7 +59,8 @@ def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk class NoShuffle(Shuffle): """NoShuffle doesn't shuffle the items and ensure all the processes receive the same number of items if drop_last - is True.""" + is True. + """ @lru_cache(maxsize=10) def get_chunks_and_intervals_per_workers( @@ -83,13 +84,10 @@ class FullShuffle(Shuffle): """FullShuffle shuffles the chunks and associates them to the ranks. As the number of items in a chunk varies, it is possible for a rank to end up with more or less items. - To ensure the same fixed dataset length for all ranks while dropping as few items as possible, - we adopt the following strategy. We compute the maximum number of items per rank (M) and iterate through the chunks and ranks - until we have associated at least M items per rank. As a result, we lose at most (number of ranks) items. However, as some chunks are shared across ranks. This leads to diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py index 25dded88..10f1f11f 100644 --- a/src/litdata/streaming/writer.py +++ b/src/litdata/streaming/writer.py @@ -58,14 +58,16 @@ def __init__( ): """The BinaryWriter enables to chunk dataset into an efficient streaming format for cloud training. - Arguments: + Args: cache_dir: The path to where the chunks will be saved. chunk_bytes: The maximum number of bytes within a chunk. chunk_size: The maximum number of items within a chunk. compression: The compression algorithm to use. encryption: The encryption algorithm to use. + follow_tensor_dimension: Whether to follow the tensor dimension when serializing the data. serializers: Provide your own serializers. chunk_index: The index of the chunk to start from. 
+ item_loader: The object responsible to generate the chunk intervals and load an item from a chunk. """ self._cache_dir = cache_dir @@ -155,7 +157,6 @@ def get_config(self) -> Dict[str, Any]: def serialize(self, items: Any) -> Tuple[bytes, Optional[int]]: """Serialize a dictionary into its binary format.""" - # Flatten the items provided by the users flattened, data_spec = tree_flatten(items) @@ -289,8 +290,8 @@ def __setitem__(self, index: int, items: Any) -> None: def add_item(self, index: int, items: Any) -> Optional[str]: """Given an index and items will serialize the items and store an Item object to the growing - `_serialized_items`.""" - + `_serialized_items`. + """ if index in self._serialized_items: raise ValueError(f"The provided index {index} already exists in the cache.") @@ -413,7 +414,8 @@ def done(self) -> List[str]: def merge(self, num_workers: int = 1, node_rank: Optional[int] = None) -> None: """Once all the workers have written their own index, the merge function is responsible to read and merge them - into a single index.""" + into a single index. + """ num_workers = num_workers or 1 # Only for non rank 0 @@ -443,7 +445,7 @@ def _merge_no_wait(self, node_rank: Optional[int] = None, existing_index: Option """Once all the workers have written their own index, the merge function is responsible to read and merge them into a single index. - Arguments: + Args: node_rank: The node rank of the index file existing_index: Existing index to be added to the newly created one. diff --git a/src/litdata/utilities/broadcast.py b/src/litdata/utilities/broadcast.py index 770aabb3..8d94c5c0 100644 --- a/src/litdata/utilities/broadcast.py +++ b/src/litdata/utilities/broadcast.py @@ -48,8 +48,7 @@ def _response(r: Any, *args: Any, **kwargs: Any) -> Any: class _HTTPClient: - """A wrapper class around the requests library which handles chores like logging, retries, and timeouts - automatically.""" + """A wrapper around the requests library which handles chores like logging, retries, and timeouts automatically.""" def __init__( self, diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index ca43bedd..a23d9e3f 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -64,7 +64,7 @@ def subsample_streaming_dataset( else: raise ValueError( f"The provided dataset `{input_dir.path}` doesn't contain any {_INDEX_FILENAME} file." - " HINT: Did you successfully optimize a dataset to the provided `input_dir`?" + "\n HINT: Did you successfully optimize a dataset to the provided `input_dir`?" ) assert len(original_chunks) > 0, f"No chunks found in the `{input_dir}/index.json` file" @@ -161,7 +161,7 @@ def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Di def generate_roi(chunks: List[Dict[str, Any]], item_loader: Optional[BaseItemLoader] = None) -> List[Tuple[int, int]]: - "Generates default region_of_interest for chunks." + """Generates default region_of_interest for chunks.""" roi = [] if isinstance(item_loader, TokensLoader): @@ -207,13 +207,14 @@ def load_index_file(input_dir: str) -> Dict[str, Any]: def adapt_mds_shards_to_chunks(data: Dict[str, Any]) -> Dict[str, Any]: """Adapt mds shard-based index data to chunk-based format for compatibility. - For more details about MDS, refer to the MosaicML Streaming documentation: https://github.com/mosaicml/streaming + For more details about MDS, refer to the MosaicML Streaming documentation: https://github.com/mosaicml/streaming. 
Args: data (Dict[str, Any]): The original index data containing shards. Returns: Dict[str, Any]: Adapted index data with chunks format. + """ chunks = [] shards = data["shards"] diff --git a/src/litdata/utilities/env.py b/src/litdata/utilities/env.py index b101f57f..74f72c46 100644 --- a/src/litdata/utilities/env.py +++ b/src/litdata/utilities/env.py @@ -36,7 +36,7 @@ def __init__(self, world_size: int, global_rank: int, num_nodes: int): def detect(cls) -> "_DistributedEnv": """Tries to automatically detect the distributed environment parameters. - Note: + .. note:: This detection may not work in processes spawned from the distributed processes (e.g. DataLoader workers) as the distributed framework won't be initialized there. It will default to 1 distributed process in this case. @@ -103,7 +103,7 @@ def __init__(self, world_size: int, rank: int): def detect(cls, get_worker_info_fn: Optional[Callable] = None) -> "_WorkerEnv": """Automatically detects the number of workers and the current rank. - Note: + .. note:: This only works reliably within a dataloader worker as otherwise the necessary information won't be present. In such a case it will default to 1 worker @@ -146,7 +146,7 @@ def from_args( """Generates the Environment class by already given arguments instead of detecting them. Args: - dist_world_size: The worldsize used for distributed training (=total number of distributed processes) + dist_world_size: The world-size used for distributed training (=total number of distributed processes) global_rank: The distributed global rank of the current process num_workers: The number of workers per distributed training process current_worker_rank: The rank of the current worker within the number of workers of @@ -162,7 +162,7 @@ def from_args( def num_shards(self) -> int: """Returns the total number of shards. - Note: + .. note:: This may not be accurate in a non-dataloader-worker process like the main training process as it doesn't necessarily know about the number of dataloader workers. @@ -175,7 +175,7 @@ def num_shards(self) -> int: def shard_rank(self) -> int: """Returns the rank of the current process wrt. the total number of shards. - Note: + .. note:: This may not be accurate in a non-dataloader-worker process like the main training process as it doesn't necessarily know about the number of dataloader workers. diff --git a/src/litdata/utilities/packing.py b/src/litdata/utilities/packing.py index 3b1c8480..0c11d9a7 100644 --- a/src/litdata/utilities/packing.py +++ b/src/litdata/utilities/packing.py @@ -17,8 +17,8 @@ def _pack_greedily(items: List[Any], weights: List[int], num_bins: int) -> Tuple[Dict[int, List[Any]], Dict[int, int]]: """Greedily pack items with given weights into bins such that the total weight of each bin is roughly equally - distributed among all bins.""" - + distributed among all bins. + """ if len(items) != len(weights): raise ValueError(f"Items and weights must have the same length, got {len(items)} and {len(weights)}.") if any(w <= 0 for w in weights): diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 9133ae05..f1861fe9 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -51,7 +51,8 @@ def _group_chunks_by_nodes( num_workers_per_process: int, ) -> List[List[int]]: """Takes a list representing chunks grouped by worker (global worker id across ranks and nodes) and returns a list - in which the chunks are grouped by node.""" + in which the chunks are grouped by node. 
+ """ chunk_indexes_per_nodes: Any = [[] for _ in range(num_nodes)] num_processes_per_node = world_size // num_nodes for worker_global_id, chunks in enumerate(chunks_per_workers): @@ -135,7 +136,6 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( to only let the worker delete a chunk when that worker is the last to read from it. """ - # Shared chunks across all workers and ranks shared_chunks = _get_shared_chunks(workers_chunks) @@ -279,7 +279,8 @@ def _aggregate_shared_chunks_per_rank( def _map_node_worker_rank_to_chunk_indexes_to_not_delete(to_disable: Dict[int, List[int]]) -> Dict[int, List[int]]: """Takes a dictionary mapping a chunk index to a list of workers and inverts the map such that it returns a - dictionary mapping a worker to a list of chunk indexes (that should not be deleted by that worker).""" + dictionary mapping a worker to a list of chunk indexes (that should not be deleted by that worker). + """ map_node_worker_rank_to_chunk_indexes: Dict[int, List[int]] = {} for chunk_index, worker_ids in to_disable.items(): for worker_idx in worker_ids: diff --git a/src/litdata/utilities/train_test_split.py b/src/litdata/utilities/train_test_split.py index d31fb076..a44b8b5c 100644 --- a/src/litdata/utilities/train_test_split.py +++ b/src/litdata/utilities/train_test_split.py @@ -20,9 +20,10 @@ def train_test_split( These subsets can be used for training, testing, and validation purposes. Args: - streaming_dataset (StreamingDataset): An instance of StreamingDataset that needs to be split. - splits (List[float]): A list of floats representing the proportion of data to be allocated to each split + streaming_dataset: An instance of StreamingDataset that needs to be split. + splits: A list of floats representing the proportion of data to be allocated to each split (e.g., [0.8, 0.1, 0.1] for 80% training, 10% testing, and 10% validation). + seed: An integer used to seed the random number generator for reproducibility. Returns: List[StreamingDataset]: A list of StreamingDataset instances, where each element represents a split of the diff --git a/tests/conftest.py b/tests/conftest.py index 32a86645..a4b1f707 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,7 @@ @pytest.fixture(autouse=True) def teardown_process_group(): # noqa: PT004 - """Ensures that the distributed process group gets closed before the next test runs.""" + """Ensures distributed process group gets closed before the next test runs.""" yield if torch.distributed.is_available() and torch.distributed.is_initialized(): torch.distributed.destroy_process_group() @@ -84,7 +84,7 @@ def lightning_sdk_mock(monkeypatch): @pytest.fixture(autouse=True) def _thread_police(): - """Attempts to stop left-over threads to avoid test interactions. + """Attempts stopping left-over threads to avoid test interactions. Adapted from PyTorch Lightning. 
diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py index 543b6909..ef111541 100644 --- a/tests/processing/test_data_processor.py +++ b/tests/processing/test_data_processor.py @@ -109,20 +109,16 @@ def fn(*_, **__): remove_queue = mock.MagicMock() - s3_client = mock.MagicMock() - called = False - def copy_file(local_filepath, *args): + def copy_file(local_filepath, *args, **kwargs): nonlocal called called = True from shutil import copyfile copyfile(local_filepath, os.path.join(remote_output_dir, os.path.basename(local_filepath))) - s3_client.client.upload_file = copy_file - - monkeypatch.setattr(data_processor_module, "S3Client", mock.MagicMock(return_value=s3_client)) + monkeypatch.setattr(data_processor_module, "upload_file_or_directory", copy_file) assert os.listdir(remote_output_dir) == [] @@ -217,32 +213,28 @@ def test_wait_for_disk_usage_higher_than_threshold(): @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows") -def test_wait_for_file_to_exist(): - import botocore - - s3 = mock.MagicMock() - obj = mock.MagicMock() +def test_wait_for_file_to_exist(monkeypatch): raise_error = [True, True, False] def fn(*_, **__): value = raise_error.pop(0) if value: - raise botocore.exceptions.ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") + raise Exception({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") # some exception return - s3.client.head_object = fn + monkeypatch.setattr(data_processor_module, "does_file_exist", fn) - _wait_for_file_to_exist(s3, obj, sleep_time=0.01) + _wait_for_file_to_exist("s3://some-dummy-bucket/some-dummy-key", sleep_time=0.01) assert len(raise_error) == 0 def fn(*_, **__): raise ValueError("HERE") - s3.client.head_object = fn + monkeypatch.setattr(data_processor_module, "does_file_exist", fn) with pytest.raises(ValueError, match="HERE"): - _wait_for_file_to_exist(s3, obj, sleep_time=0.01) + _wait_for_file_to_exist("s3://some-dummy-bucket/some-dummy-key", sleep_time=0.01) def test_cache_dir_cleanup(tmpdir, monkeypatch): @@ -460,8 +452,7 @@ def _broadcast_object(self, obj: Any) -> Any: condition=(not _PIL_AVAILABLE or sys.platform == "win32" or sys.platform == "linux"), reason="Requires: ['pil']" ) def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir, monkeypatch): - """This test ensures the data optimizer works in a fully distributed settings.""" - + """Ensures the data optimizer works in a fully distributed settings.""" seed_everything(42) monkeypatch.setattr(data_processor_module.os, "_exit", mock.MagicMock()) @@ -1025,11 +1016,10 @@ def test_data_processing_map_non_absolute_path(monkeypatch, tmpdir): @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows") def test_map_error_when_not_empty(monkeypatch): - boto3 = mock.MagicMock() - client_s3_mock = mock.MagicMock() - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []} - boto3.client.return_value = client_s3_mock - monkeypatch.setattr(resolver, "boto3", boto3) + def mock_list_directory(*args, **kwargs): + return ["a.txt", "b.txt"] + + monkeypatch.setattr(resolver, "list_directory", mock_list_directory) with pytest.raises(RuntimeError, match="data and datasets are meant to be immutable"): map( diff --git a/tests/streaming/test_client.py b/tests/streaming/test_client.py deleted file mode 100644 index 78ea919d..00000000 --- a/tests/streaming/test_client.py +++ /dev/null @@ -1,97 +0,0 @@ -import sys -from time 
import sleep, time -from unittest import mock - -import pytest -from litdata.streaming import client - - -def test_s3_client_with_storage_options(monkeypatch): - boto3 = mock.MagicMock() - monkeypatch.setattr(client, "boto3", boto3) - - botocore = mock.MagicMock() - monkeypatch.setattr(client, "botocore", botocore) - - storage_options = { - "region_name": "us-west-2", - "endpoint_url": "https://custom.endpoint", - "config": botocore.config.Config(retries={"max_attempts": 100}), - } - s3_client = client.S3Client(storage_options=storage_options) - - assert s3_client.client - - boto3.client.assert_called_with( - "s3", - region_name="us-west-2", - endpoint_url="https://custom.endpoint", - config=botocore.config.Config(retries={"max_attempts": 100}), - ) - - s3_client = client.S3Client() - - assert s3_client.client - - boto3.client.assert_called_with( - "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}) - ) - - -def test_s3_client_without_cloud_space_id(monkeypatch): - boto3 = mock.MagicMock() - monkeypatch.setattr(client, "boto3", boto3) - - botocore = mock.MagicMock() - monkeypatch.setattr(client, "botocore", botocore) - - instance_metadata_provider = mock.MagicMock() - monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider) - - instance_metadata_fetcher = mock.MagicMock() - monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher) - - s3 = client.S3Client(1) - assert s3.client - assert s3.client - assert s3.client - assert s3.client - assert s3.client - - boto3.client.assert_called_once() - - -@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows") -@pytest.mark.parametrize("use_shared_credentials", [False, True, None]) -def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch): - boto3 = mock.MagicMock() - monkeypatch.setattr(client, "boto3", boto3) - - botocore = mock.MagicMock() - monkeypatch.setattr(client, "botocore", botocore) - - if isinstance(use_shared_credentials, bool): - monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy") - monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", "/.credentials/.aws_credentials") - monkeypatch.setenv("AWS_CONFIG_FILE", "/.credentials/.aws_credentials") - - instance_metadata_provider = mock.MagicMock() - monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider) - - instance_metadata_fetcher = mock.MagicMock() - monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher) - - s3 = client.S3Client(1) - assert s3.client - assert s3.client - boto3.client.assert_called_once() - sleep(1 - (time() - s3._last_time)) - assert s3.client - assert s3.client - assert len(boto3.client._mock_mock_calls) == 6 - sleep(1 - (time() - s3._last_time)) - assert s3.client - assert s3.client - assert len(boto3.client._mock_mock_calls) == 9 - - assert instance_metadata_provider._mock_call_count == 0 if use_shared_credentials else 3 diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 78cb0eaa..b87f9ffd 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -889,11 +889,10 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False): @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) -@pytest.mark.timeout(60) +@pytest.mark.timeout(120) @pytest.mark.parametrize("shuffle", [True, False]) def test_dataset_resume_on_future_chunks(shuffle, tmpdir, 
monkeypatch): - """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have - the same size.""" + """Tests resuming from a chunk past the first chunk, when subsequent chunks don't have the same size.""" s3_cache_dir = str(tmpdir / "s3cache") optimize_data_cache_dir = str(tmpdir / "optimize_data_cache") optimize_cache_dir = str(tmpdir / "optimize_cache") diff --git a/tests/streaming/test_downloader.py b/tests/streaming/test_downloader.py index 7c79afe5..97368c0c 100644 --- a/tests/streaming/test_downloader.py +++ b/tests/streaming/test_downloader.py @@ -1,84 +1,19 @@ import os -from unittest import mock from unittest.mock import MagicMock from litdata.streaming.downloader import ( - AzureDownloader, - GCPDownloader, LocalDownloaderWithCache, - S3Downloader, shutil, - subprocess, ) -def test_s3_downloader_fast(tmpdir, monkeypatch): - monkeypatch.setattr(os, "system", MagicMock(return_value=0)) - popen_mock = MagicMock() - monkeypatch.setattr(subprocess, "Popen", MagicMock(return_value=popen_mock)) - downloader = S3Downloader(tmpdir, tmpdir, []) - downloader.download_file("s3://random_bucket/a.txt", os.path.join(tmpdir, "a.txt")) - popen_mock.wait.assert_called() - - -@mock.patch("litdata.streaming.downloader._GOOGLE_STORAGE_AVAILABLE", True) -def test_gcp_downloader(tmpdir, monkeypatch, google_mock): - # Create mock objects - mock_client = MagicMock() - mock_bucket = MagicMock() - mock_blob = MagicMock() - mock_blob.download_to_filename = MagicMock() - - # Patch the storage client to return the mock client - google_mock.cloud.storage.Client = MagicMock(return_value=mock_client) - - # Configure the mock client to return the mock bucket and blob - mock_client.bucket = MagicMock(return_value=mock_bucket) - mock_bucket.blob = MagicMock(return_value=mock_blob) - - # Initialize the downloader - storage_options = {"project": "DUMMY_PROJECT"} - downloader = GCPDownloader("gs://random_bucket", tmpdir, [], storage_options) - local_filepath = os.path.join(tmpdir, "a.txt") - downloader.download_file("gs://random_bucket/a.txt", local_filepath) - - # Assert that the correct methods were called - google_mock.cloud.storage.Client.assert_called_with(**storage_options) - mock_client.bucket.assert_called_with("random_bucket") - mock_bucket.blob.assert_called_with("a.txt") - mock_blob.download_to_filename.assert_called_with(local_filepath) - - -@mock.patch("litdata.streaming.downloader._AZURE_STORAGE_AVAILABLE", True) -def test_azure_downloader(tmpdir, monkeypatch, azure_mock): - mock_blob = MagicMock() - mock_blob_data = MagicMock() - mock_blob.download_blob.return_value = mock_blob_data - service_mock = MagicMock() - service_mock.get_blob_client.return_value = mock_blob - - azure_mock.storage.blob.BlobServiceClient = MagicMock(return_value=service_mock) - - # Initialize the downloader - storage_options = {"project": "DUMMY_PROJECT"} - downloader = AzureDownloader("azure://random_bucket", tmpdir, [], storage_options) - local_filepath = os.path.join(tmpdir, "a.txt") - downloader.download_file("azure://random_bucket/a.txt", local_filepath) - - # Assert that the correct methods were called - azure_mock.storage.blob.BlobServiceClient.assert_called_with(**storage_options) - service_mock.get_blob_client.assert_called_with(container="random_bucket", blob="a.txt") - mock_blob.download_blob.assert_called() - mock_blob_data.readinto.assert_called() - - def test_download_with_cache(tmpdir, monkeypatch): # Create a file to download/cache with open("a.txt", "w") as f: 
f.write("hello") try: - local_downloader = LocalDownloaderWithCache(tmpdir, tmpdir, []) + local_downloader = LocalDownloaderWithCache("file", tmpdir, tmpdir, []) shutil_mock = MagicMock() os_mock = MagicMock() monkeypatch.setattr(shutil, "copy", shutil_mock) diff --git a/tests/streaming/test_resolver.py b/tests/streaming/test_resolver.py index 742f066d..699a39c4 100644 --- a/tests/streaming/test_resolver.py +++ b/tests/streaming/test_resolver.py @@ -23,7 +23,9 @@ def test_src_resolver_s3_connections(monkeypatch, lightning_cloud_mock): auth = login.Auth() auth.save(user_id="7c8455e3-7c5f-4697-8a6d-105971d6b9bd", api_key="e63fae57-2b50-498b-bc46-d6204cbf330e") - with pytest.raises(RuntimeError, match="`project_id` couldn't be found from the environment variables."): + with pytest.raises( + RuntimeError, match="`LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables." + ): resolver._resolve_dir("/teamspace/s3_connections/imagenet") monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "project_id") @@ -60,12 +62,12 @@ def test_src_resolver_studios(monkeypatch, lightning_cloud_mock): auth = login.Auth() auth.save(user_id="7c8455e3-7c5f-4697-8a6d-105971d6b9bd", api_key="e63fae57-2b50-498b-bc46-d6204cbf330e") - with pytest.raises(RuntimeError, match="`cluster_id`"): + with pytest.raises(RuntimeError, match="`LIGHTNING_CLUSTER_ID`"): resolver._resolve_dir("/teamspace/studios/other_studio") monkeypatch.setenv("LIGHTNING_CLUSTER_ID", "cluster_id") - with pytest.raises(RuntimeError, match="`project_id`"): + with pytest.raises(RuntimeError, match="`LIGHTNING_CLOUD_PROJECT_ID`"): resolver._resolve_dir("/teamspace/studios/other_studio") monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "project_id") @@ -138,17 +140,17 @@ def test_src_resolver_datasets(monkeypatch, lightning_cloud_mock): assert resolver._resolve_dir("s3://bucket_name").url == "s3://bucket_name" - with pytest.raises(RuntimeError, match="`cluster_id`"): + with pytest.raises(RuntimeError, match="`LIGHTNING_CLUSTER_ID`"): resolver._resolve_dir("/teamspace/datasets/imagenet") monkeypatch.setenv("LIGHTNING_CLUSTER_ID", "cluster_id") - with pytest.raises(RuntimeError, match="`project_id`"): + with pytest.raises(RuntimeError, match="`LIGHTNING_CLOUD_PROJECT_ID`"): resolver._resolve_dir("/teamspace/datasets/imagenet") monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "project_id") - with pytest.raises(RuntimeError, match="`cloud_space_id`"): + with pytest.raises(RuntimeError, match="`LIGHTNING_CLOUD_SPACE_ID`"): resolver._resolve_dir("/teamspace/datasets/imagenet") monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "cloud_space_id") @@ -300,52 +302,54 @@ def print_fn(msg, file=None): def test_assert_dir_is_empty(monkeypatch): - boto3 = mock.MagicMock() - client_s3_mock = mock.MagicMock() - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []} - boto3.client.return_value = client_s3_mock - resolver.boto3 = boto3 + def mock_list_directory(*args, **kwargs): + return ["a.txt", "b.txt"] + + def mock_empty_list_directory(*args, **kwargs): + return [] + + monkeypatch.setattr(resolver, "list_directory", mock_list_directory) with pytest.raises(RuntimeError, match="The provided output_dir"): resolver._assert_dir_is_empty(resolver.Dir(path="/teamspace/...", url="s3://")) - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 0, "Contents": []} - boto3.client.return_value = client_s3_mock - resolver.boto3 = boto3 + monkeypatch.setattr(resolver, "list_directory", mock_empty_list_directory) 
resolver._assert_dir_is_empty(resolver.Dir(path="/teamspace/...", url="s3://")) def test_assert_dir_has_index_file(monkeypatch): - boto3 = mock.MagicMock() - client_s3_mock = mock.MagicMock() - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []} - boto3.client.return_value = client_s3_mock - resolver.boto3 = boto3 + def mock_list_directory_0(*args, **kwargs): + return [] - with pytest.raises(RuntimeError, match="The provided output_dir"): - resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) + def mock_list_directory_1(*args, **kwargs): + return ["a.txt", "b.txt"] - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 0, "Contents": []} - boto3.client.return_value = client_s3_mock - resolver.boto3 = boto3 + def mock_list_directory_2(*args, **kwargs): + return ["index.json"] - resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) + def mock_does_file_exist_1(*args, **kwargs): + raise Exception({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") # some exception - client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []} + def mock_does_file_exist_2(*args, **kwargs): + return True - def head_object(*args, **kwargs): - import botocore + def mock_remove_file_or_directory(*args, **kwargs): + return - raise botocore.exceptions.ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") - - client_s3_mock.head_object = head_object - boto3.client.return_value = client_s3_mock - resolver.boto3 = boto3 + monkeypatch.setattr(resolver, "list_directory", mock_list_directory_0) + monkeypatch.setattr(resolver, "does_file_exist", mock_does_file_exist_1) + monkeypatch.setattr(resolver, "remove_file_or_directory", mock_remove_file_or_directory) resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) - boto3.resource.assert_called() + monkeypatch.setattr(resolver, "list_directory", mock_list_directory_2) + monkeypatch.setattr(resolver, "does_file_exist", mock_does_file_exist_2) + + with pytest.raises(RuntimeError, match="The provided output_dir"): + resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://")) + + resolver._assert_dir_has_index_file(resolver.Dir(path="/teamspace/...", url="s3://"), mode="overwrite") def test_resolve_dir_absolute(tmp_path, monkeypatch): @@ -365,3 +369,10 @@ def test_resolve_dir_absolute(tmp_path, monkeypatch): link.symlink_to(src) assert link.resolve() == src assert resolver._resolve_dir(str(link)).path == str(src) + + +def test_resolve_dir_unsupported_cloud_provider(monkeypatch, tmp_path): + """Test that the unsupported cloud provider is handled correctly.""" + test_dir = "some-random-cloud-provider://some-random-bucket" + with pytest.raises(ValueError, match="The provided dir_path"): + resolver._resolve_dir(test_dir)
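For reference, a sketch of the `_resolve_dir` behaviour that the resolver changes and the new test above establish; the URLs are placeholders, and the set of accepted schemes is whatever `_SUPPORTED_CLOUD_PROVIDERS` defines.

```python
# Sketch of the expected `_resolve_dir` behaviour after this patch; URLs are
# placeholders and scheme support is defined by `_SUPPORTED_CLOUD_PROVIDERS`.
from litdata.streaming import resolver

# Supported cloud prefixes resolve to a Dir with `url` set and `path=None`.
remote = resolver._resolve_dir("s3://my-bucket/my-data")
assert remote.url == "s3://my-bucket/my-data" and remote.path is None

# Unknown schemes are rejected instead of being treated as local paths.
try:
    resolver._resolve_dir("some-random-cloud-provider://some-bucket")
except ValueError as err:
    print(err)
```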