From 2f78ec1ddee615476cb521abdfd75521c0d421e5 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 17 Sep 2024 20:41:50 +0200 Subject: [PATCH] fixing docstrings (#374) * fixing docstrings * cleaning --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 7 -- examples/multi_modal/create_labelencoder.py | 6 +- examples/multi_modal/dataloader.py | 44 ++++----- examples/multi_modal/loop.py | 104 +++++--------------- examples/multi_modal/model_arc.py | 16 ++- pyproject.toml | 39 +++++--- src/litdata/processing/data_processor.py | 22 ++--- src/litdata/processing/functions.py | 17 ++-- src/litdata/processing/readers.py | 2 +- src/litdata/processing/utilities.py | 1 - src/litdata/streaming/cache.py | 3 +- src/litdata/streaming/combined.py | 9 +- src/litdata/streaming/config.py | 4 +- src/litdata/streaming/dataloader.py | 4 +- src/litdata/streaming/dataset.py | 2 +- src/litdata/streaming/item_loader.py | 10 +- src/litdata/streaming/reader.py | 2 +- src/litdata/streaming/resolver.py | 1 - src/litdata/streaming/sampler.py | 2 +- src/litdata/streaming/shuffle.py | 6 +- src/litdata/streaming/writer.py | 14 +-- src/litdata/utilities/broadcast.py | 3 +- src/litdata/utilities/dataset_utilities.py | 5 +- src/litdata/utilities/env.py | 10 +- src/litdata/utilities/packing.py | 4 +- src/litdata/utilities/shuffle.py | 7 +- src/litdata/utilities/train_test_split.py | 5 +- tests/conftest.py | 4 +- tests/processing/test_data_processor.py | 3 +- tests/streaming/test_dataset.py | 3 +- 30 files changed, 152 insertions(+), 207 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 93cb15a1..104ebe5c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -46,13 +46,6 @@ repos: additional_dependencies: [tomli] #args: ["--write-changes"] # uncomment if you want to get automatic fixing - - repo: https://github.com/PyCQA/docformatter - rev: v1.7.5 - hooks: - - id: docformatter - additional_dependencies: [tomli] - args: ["--in-place"] - - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.3.2 hooks: diff --git a/examples/multi_modal/create_labelencoder.py b/examples/multi_modal/create_labelencoder.py index 657318bc..83549924 100644 --- a/examples/multi_modal/create_labelencoder.py +++ b/examples/multi_modal/create_labelencoder.py @@ -3,11 +3,7 @@ def create_labelencoder(): - """ - Create a label encoder - Returns: - - """ + """Create a label encoder.""" data = ["Cancelation", "IBAN Change", "Damage Report"] # Create an instance of LabelEncoder label_encoder = LabelEncoder() diff --git a/examples/multi_modal/dataloader.py b/examples/multi_modal/dataloader.py index df28dc2e..c33a3cf7 100644 --- a/examples/multi_modal/dataloader.py +++ b/examples/multi_modal/dataloader.py @@ -29,15 +29,12 @@ def __init__(self): self.hyperparameters = HYPERPARAMETERS def load_labelencoder(self): - """ - Function to load the label encoder from s3 - Returns: - """ + """Function to load the label encoder from s3.""" return joblib.load(self.hyperparameters["label_encoder_name"]) def load_tokenizer(self): - """ - load the tokenizer files and the pre-training model path from s3 specified in the hyperparameters + """Loads the tokenizer files and the pre-training model path from s3 specified in the hyperparameters. 
+ Returns: tokenizer """ # Load Bert tokenizer @@ -62,13 +59,10 @@ def __init__(self, input_dir: Union[str, Any], hyperparameters: Union[dict, Any] self.labelencoder = EC.load_labelencoder() def tokenize_data(self, tokenizer, texts, max_length: int): - """ - Tokenize the text - Args: - tokenizer: - texts: - max_length: - Returns: input_ids, attention_masks + """Tokenize the text. + + Returns: input_ids, attention_masks. + """ encoded_text = tokenizer( texts, @@ -98,11 +92,10 @@ class MixedDataModule(pl.LightningDataModule): """Own DataModule form the pytorch lightning DataModule.""" def __init__(self, hyperparameters: dict): - """ - Init if the Data Module + """Initialize if the Data Module. + Args: - data_path: dataframe with the data - hyperparameters: Hyperparameters + hyperparameters: Hyperparameters. """ super().__init__() self.hyperparameters = hyperparameters @@ -130,10 +123,11 @@ def __init__(self, hyperparameters: dict): ) def train_dataloader(self) -> DataLoader: - """ - Define the training dataloader + """Define the training dataloader. + Returns: - training dataloader + training dataloader. + """ dataset_train = DocumentClassificationDataset( hyperparameters=self.hyperparameters, @@ -150,10 +144,10 @@ def train_dataloader(self) -> DataLoader: ) def val_dataloader(self) -> DataLoader: - """ - Define the validation dataloader + """Defines the validation dataloader. + Returns: - validation dataloader + validation dataloader. """ dataset_val = DocumentClassificationDataset( hyperparameters=self.hyperparameters, @@ -169,8 +163,8 @@ def val_dataloader(self) -> DataLoader: ) def test_dataloader(self) -> DataLoader: - """ - Define the test dataloader + """Defines the test dataloader. + Returns: test dataloader """ diff --git a/examples/multi_modal/loop.py b/examples/multi_modal/loop.py index 6104f6e3..2c2de659 100644 --- a/examples/multi_modal/loop.py +++ b/examples/multi_modal/loop.py @@ -77,7 +77,6 @@ def save_reports(self, model_dir, mode, report_confusion_matrix, report): mode: train, test or val report_confusion_matrix: sklearn confusion matrix report: sklear classification report - Returns: """ df_cm = pd.DataFrame(report_confusion_matrix) @@ -87,17 +86,7 @@ def save_reports(self, model_dir, mode, report_confusion_matrix, report): logger.info("Confusion Matrix and Classification report are saved.") def save_test_evaluations(self, model_dir, mode, y_pred, y_true, confis, numerical_id_): - """ - Save a pandas dataframe with prediction and ground truth and identifier (numerical id) of the test dataset - Args: - model_dir: - mode: - y_pred: - y_true: - confis: - numerical_id_: - Returns: - """ + """Save pandas dataframe with prediction and ground truth and identifier (numerical id) of the test dataset.""" df_test = pd.DataFrame() df_test["pred"] = y_pred df_test["confidence"] = confis.max(axis=1) @@ -151,8 +140,7 @@ def forward( """Forward path, calculate the computational graph in the forward direction. Used for train, test and val. - Args: - y: tensor with text data as tokens + Returns: computional graph @@ -160,34 +148,29 @@ def forward( return self.module(x, y, z) def training_step(self, batch: Dict[str, torch.Tensor]) -> Dict: - """ - Call the eval share for training - Args: - batch: tensor + """Call the eval share for training. + Returns: - dict with loss, outputs and ground_truth + dict with loss, outputs and ground_truth. 
+ """ return self._shared_eval_step(batch, "train") def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict: - """ - Call the eval share for validation - Args: - batch: - batch_idx: + """Call the eval share for validation. + Returns: - dict with loss, outputs and ground_truth + dict with loss, outputs and ground_truth. + """ return self._shared_eval_step(batch, "val") def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict: - """ - Call the eval share for test - Args: - batch: - batch_idx: + """Call the eval share for test. + Returns: - dict with loss, outputs and ground_truth + dict with loss, outputs and ground_truth. + """ ret = self._shared_eval_step(batch, "test") self.pred_list.append(ret) @@ -199,6 +182,7 @@ def _shared_eval_step(self, batch: Dict[str, torch.Tensor], mode: str) -> Dict: Args: batch: tensor mode: train, test or val + Returns: dict with loss, outputs and ground_truth @@ -227,14 +211,8 @@ def _shared_eval_step(self, batch: Dict[str, torch.Tensor], mode: str) -> Dict: return {"outputs": out, "loss": loss, "ground_truth": ground_truth, "numerical_id": numerical_id} - def _epoch_end(self, mode: str): - """ - Calculate loss and metricies at end of epoch - Args: - mode: - Returns: - None - """ + def _epoch_end(self, mode: str) -> None: + """Calculate loss and metrics at end of epoch.""" if mode == "val": output = self.val_metrics.compute() self.log_dict(output) @@ -249,14 +227,7 @@ def _epoch_end(self, mode: str): self.test_metrics.reset() def predict(self, batch: Dict[str, torch.Tensor], batch_idx: int = 0, dataloader_idx: int = 0) -> torch.Tensor: - """Model prediction without softmax and argmax to predict class label. - - Args: - outputs: - Returns: - None - - """ + """Model prediction without softmax and argmax to predict class label.""" self.eval() with torch.no_grad(): ids = batch["ID"] @@ -265,51 +236,30 @@ def predict(self, batch: Dict[str, torch.Tensor], batch_idx: int = 0, dataloader return self.forward(ids, atts, img) def on_test_epoch_end(self) -> None: - """ - Calculate the metrics at the end of epoch for test step - Args: - outputs: - Returns: - None - """ + """Calculate the metrics at the end of epoch for test step.""" self._epoch_end("test") - def on_validation_epoch_end(self): - """ - Calculate the metrics at the end of epoch for val step - Args: - outputs: - Returns: - None - """ + def on_validation_epoch_end(self) -> None: + """Calculate the metrics at the end of epoch for val step.""" self._epoch_end("val") - def on_train_epoch_end(self): - """ - Calculate the metrics at the end of epoch for train step - Args: - outputs: - Returns: - None - """ + def on_train_epoch_end(self) -> None: + """Calculate the metrics at the end of epoch for train step.""" self._epoch_end("train") def configure_optimizers(self) -> Any: - """ - Configure the optimizer + """Configure the optimizer. + Returns: optimizer + """ optimizer = AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.hyperparameters["weight_decay"]) scheduler = StepLR(optimizer, step_size=1, gamma=0.1) return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}] def configure_callbacks(self) -> Union[Sequence[pl.pytorch.Callback], pl.pytorch.Callback]: - """Configure Early stopping or Model Checkpointing. 
- - Returns: - - """ + """Configure Early stopping or Model Checkpointing.""" early_stop = EarlyStopping( monitor="val_MulticlassAccuracy", patience=self.hyperparameters["patience"], mode="max" ) diff --git a/examples/multi_modal/model_arc.py b/examples/multi_modal/model_arc.py index 4ea8cddb..8eb57ca4 100644 --- a/examples/multi_modal/model_arc.py +++ b/examples/multi_modal/model_arc.py @@ -41,9 +41,7 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor): """Forward path, calculate the computational graph in the forward direction. Used for train, test and val. - Args: - input_ids - attention_mask + Returns: computional graph @@ -72,9 +70,9 @@ def __init__(self, endpoint_mode: bool, hyperparameters: dict): self.dropout = nn.Dropout(self.hyperparameters["dropout"]) def get_bert_model(self): - """ - Load the pre trained bert model weights - Returns: model + """Load the pre-trained bert model weights. + + Returns: model. """ model = BertModel.from_pretrained("bert-base-cased") return BertClassifier(model) @@ -89,9 +87,9 @@ def forward( validation. Args: - x (torch.Tensor): Tensor with id token - y (torch.Tensor): Tensor with attention tokens. - z (torch.Tensor): Tensor with image. + x: Tensor with id token + y: Tensor with attention tokens. + z: Tensor with image. Returns: torch.Tensor: The output tensor representing the computational graph. diff --git a/pyproject.toml b/pyproject.toml index b78d1cf0..b0883f1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,12 @@ ignore-words-list = "cancelation" [tool.ruff] line-length = 120 target-version = "py38" +# Exclude a variety of commonly ignored directories. +exclude = [ + ".git", + "docs", + "src/litdata/utilities/_pytree.py", +] # Enable Pyflakes `E` and `F` codes by default. lint.select = [ "E", "W", # see: https://pypi.org/project/pycodestyle @@ -53,45 +59,54 @@ lint.select = [ "UP", # see: pyupgrade ] lint.extend-select = [ + "D", "I", # see: isort "C4", # see: https://pypi.org/project/flake8-comprehensions "SIM", # see: https://pypi.org/project/flake8-simplify "RET", # see: https://pypi.org/project/flake8-return "PT", # see: https://pypi.org/project/flake8-pytest-style "NPY201", # see: https://docs.astral.sh/ruff/rules/numpy2-deprecation - "RUF100" # yesqa + "RUF100", # yesqa ] lint.ignore = [ "E731", # Do not assign a lambda expression, use a def "S101", # todo: Use of `assert` detected ] -# Exclude a variety of commonly ignored directories. -exclude = [ - ".git", - "docs", - "src/litdata/utilities/_pytree.py", -] lint.ignore-init-module-imports = true +# Unlike Flake8, default to a complexity level of 10. +lint.mccabe.max-complexity = 10 +# Use Google-style docstrings. 
+lint.pydocstyle.convention = "google" [tool.ruff.lint.per-file-ignores] -".actions/*" = ["S101", "S310"] -"setup.py" = ["S101", "SIM115"] +"setup.py" = ["D100", "SIM115"] "examples/**" = [ + "D100", "D101", "D102", "D103", "D104", "D105", "D107", # Missing docstring in public module, class, method, function, package + "D205", # todo: 1 blank line required between summary line and description + "D401", "D404", # First line should be in imperative mood; try rephrasing "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes ] "src/**" = [ + "D100", # Missing docstring in public module + "D101", # todo: Missing docstring in public class + "D102", # todo: Missing docstring in public method + "D103", # todo: Missing docstring in public function + "D104", # Missing docstring in public package + "D105", # todo: Missing docstring in magic method + "D107", # todo: Missing docstring in __init__ + "D205", # todo: 1 blank line required between summary line and description + "D401", "D404", # todo: First line should be in imperative mood; try rephrasing "S602", # todo: `subprocess` call with `shell=True` identified, security issue "S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell` "S607", # todo: Starting a process with a partial executable path "S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. ] "tests/**" = [ + "D100", "D101", "D102", "D103", "D104", "D105", "D107", # Missing docstring in public module, class, method, function, package + "D401", "D404", # First line should be in imperative mood; try rephrasing "S105", "S106", # todo: Possible hardcoded password: ... ] -[tool.ruff.lint.mccabe] -# Unlike Flake8, default to a complexity level of 10. -max-complexity = 10 [tool.mypy] diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index b957941e..f1af9afa 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -119,7 +119,7 @@ def _wait_for_disk_usage_higher_than_threshold(input_dir: str, threshold_in_gb: def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue) -> None: - """This function is used to download data from a remote directory to a cache directory to optimise reading.""" + """Download data from a remote directory to a cache directory to optimise reading.""" s3 = S3Client() while True: @@ -176,7 +176,7 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue def _remove_target(input_dir: Dir, cache_dir: str, queue_in: Queue) -> None: - """This function is used to delete files from the cache directory to minimise disk space.""" + """Delete files from the cache directory to minimise disk space.""" while True: # 1. 
Collect paths paths = queue_in.get() @@ -199,7 +199,7 @@ def _remove_target(input_dir: Dir, cache_dir: str, queue_in: Queue) -> None: def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir) -> None: - """This function is used to upload optimised chunks from a local to remote dataset directory.""" + """Upload optimised chunks from a local to remote dataset directory.""" obj = parse.urlparse(output_dir.url if output_dir.url else output_dir.path) if obj.scheme == "s3": @@ -787,7 +787,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> List[T]: @abstractmethod def prepare_item(self, item_metadata: T) -> Any: - """The return of this `prepare_item` method is persisted in chunked binary files.""" + """Returns `prepare_item` method is persisted in chunked binary files.""" def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Result: num_nodes = _get_num_nodes() @@ -832,7 +832,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul ) def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_rank: Optional[int]) -> None: - """This method upload the index file to the remote cloud directory.""" + """Upload the index file to the remote cloud directory.""" if output_dir.path is None and output_dir.url is None: return @@ -909,10 +909,9 @@ def __init__( item_loader: Optional[BaseItemLoader] = None, start_method: Optional[str] = None, ): - """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make - training faster. + """Provides an efficient way to process data across multiple machine into chunks to make training faster. - Arguments: + Args: input_dir: The path to where the input data are stored. output_dir: The path to where the output data are stored. num_workers: The number of worker threads to use. @@ -984,7 +983,7 @@ def __init__( self.random_seed = random_seed def run(self, data_recipe: DataRecipe) -> None: - """The `DataProcessor.run(...)` method triggers the data recipe processing over your dataset.""" + """Triggers the data recipe processing over your dataset.""" if not isinstance(data_recipe, DataRecipe): raise ValueError("The provided value should be a data recipe.") if not self.use_checkpoint and isinstance(data_recipe, DataChunkRecipe): @@ -1392,6 +1391,7 @@ def _load_checkpoint_config(self, workers_user_items: List[List[Any]]) -> None: def in_notebook() -> bool: - """Returns ``True`` if the module is running in IPython kernel, ``False`` if in IPython shell or other Python - shell.""" + """Returns ``True`` if the module is running in IPython kernel, ``False`` if in IPython or other Python + shell. + """ return "ipykernel" in sys.modules diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index 8edc2319..24b7564d 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -180,7 +180,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> Any: return self._inputs def prepare_item(self, item_metadata: Any) -> Any: - """This method is overridden dynamically.""" + """Being overridden dynamically.""" def map( @@ -200,9 +200,9 @@ def map( reader: Optional[BaseReader] = None, batch_size: Optional[int] = None, ) -> None: - """This function maps a callable over a collection of inputs, possibly in a distributed way. + """Maps a callable over a collection of inputs, possibly in a distributed way. 
- Arguments: + Args: fn: A function to be executed over each input element inputs: A sequence of input to be processed by the `fn` function, or a streaming dataloader. output_dir: The folder where the processed data should be stored. @@ -218,6 +218,7 @@ def map( reorder_files: By default, reorders the files by file size to distribute work equally among all workers. Set this to ``False`` if the order in which samples are processed should be preserved. error_when_not_empty: Whether we should error if the output folder isn't empty. + reader: The reader to use when reading the data. By default, it uses the `BaseReader`. batch_size: Group the inputs into batches of batch_size length. """ @@ -317,7 +318,7 @@ def optimize( ) -> None: """This function converts a dataset into chunks, possibly in a distributed way. - Arguments: + Args: fn: A function to be executed over each input element. The function should return the data sample that corresponds to the input. Every invocation of the function should return a similar hierarchy of objects, where the object types and list sizes don't change. @@ -336,6 +337,7 @@ def optimize( machine: When doing remote execution, the machine to use. Only supported on https://lightning.ai/. num_downloaders: The number of downloaders per worker. num_uploaders: The numbers of uploaders per worker. + reader: The reader to use when reading the data. By default, it uses the `BaseReader`. reorder_files: By default, reorders the files by file size to distribute work equally among all workers. Set this to ``False`` if the order in which samples are processed should be preserved. batch_size: Group the inputs into batches of batch_size length. @@ -468,7 +470,7 @@ def _listdir(folder: str) -> Tuple[str, List[str]]: class walk: """This class is an optimized version of os.walk for listing files and folders from cloud filesystem. - Note: The order of files and folders yielded aren't depth-first anymore due to the asynchronous listing call. + .. note:: The order of files and folders yielded aren't depth-first anymore due to the asynchronous listing call. """ @@ -520,10 +522,9 @@ class CopyInfo: def merge_datasets(input_dirs: List[str], output_dir: str) -> None: - """The merge_datasets utility enables to merge multiple existing optimized datasets into a single optimized - dataset. + """Enables to merge multiple existing optimized datasets into a single optimized dataset. - Arguments: + Args: input_dirs: A list of directories pointing to the existing optimized datasets. output_dir: The directory where the merged dataset would be stored. 
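For reference, the `optimize` and `merge_datasets` helpers documented above are typically driven as in the following minimal sketch. It assumes the top-level `litdata` exports; the file paths and parameter values are illustrative, not part of this patch.

from litdata import merge_datasets, optimize


def load_sample(filepath: str) -> dict:
    # Every invocation must return the same hierarchy of objects,
    # as required by `optimize`.
    with open(filepath) as f:
        return {"text": f.read()}


if __name__ == "__main__":
    # Convert raw inputs into a chunked, streamable dataset.
    optimize(
        fn=load_sample,
        inputs=["data/a.txt", "data/b.txt"],  # hypothetical input files
        output_dir="optimized/part_1",
        chunk_bytes="64MB",
        num_workers=2,
    )

    # Merge several previously optimized datasets into a single one.
    merge_datasets(
        input_dirs=["optimized/part_1", "optimized/part_2"],  # hypothetical dataset directories
        output_dir="optimized/merged",
    )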
diff --git a/src/litdata/processing/readers.py b/src/litdata/processing/readers.py index 80ae16bb..beec5664 100644 --- a/src/litdata/processing/readers.py +++ b/src/litdata/processing/readers.py @@ -33,7 +33,7 @@ def get_node_rank(self) -> int: @abstractmethod def remap_items(self, items: Any, num_workers: int) -> List[Any]: - """This method is meant to remap the items provided by the users into items more adapted to be distributed.""" + """Remap the items provided by the users into items more adapted to be distributed.""" @abstractmethod def read(self, item: Any) -> Any: diff --git a/src/litdata/processing/utilities.py b/src/litdata/processing/utilities.py index 87a5b3e7..9547ec66 100644 --- a/src/litdata/processing/utilities.py +++ b/src/litdata/processing/utilities.py @@ -251,7 +251,6 @@ def remove_uuid_from_filename(filepath: str) -> str: -> `checkpoint-0.json` """ - if not filepath.__contains__(".checkpoints"): return filepath diff --git a/src/litdata/streaming/cache.py b/src/litdata/streaming/cache.py index 816ef816..d045d5da 100644 --- a/src/litdata/streaming/cache.py +++ b/src/litdata/streaming/cache.py @@ -51,11 +51,12 @@ def __init__( """The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements together in order to accelerate fetching. - Arguments: + Args: input_dir: The path to where the chunks will be or are stored. subsampled_files: List of subsampled chunk files loaded from `input_dir/index.json` file. region_of_interest: List of tuples of (start,end) of region of interest for each chunk. compression: The name of the algorithm to reduce the size of the chunks. + encryption: The encryption algorithm to use. chunk_bytes: The maximum number of bytes within a chunk. chunk_size: The maximum number of items within a chunk. item_loader: The object responsible to generate the chunk intervals and load an item froma chunk. diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py index f3fd0e9b..22149589 100644 --- a/src/litdata/streaming/combined.py +++ b/src/litdata/streaming/combined.py @@ -25,7 +25,7 @@ class CombinedStreamingDataset(IterableDataset): - """The `CombinedStreamingDataset` enables to stream data from multiple StreamingDataset with the sampling ratio of + """Enables to stream data from multiple StreamingDataset with the sampling ratio of your choice. Additionally, the `CombinedStreamingDataset` keeps track of the number of samples fetched to enable reusability @@ -43,15 +43,16 @@ def __init__( weights: Optional[Sequence[float]] = None, iterate_over_all: bool = True, ) -> None: - """ " - Arguments: + """Enable to stream data from multiple StreamingDataset with the sampling ratio of your choice. + + Args: datasets: The list of the StreamingDataset to use. seed: The random seed to initialize the sampler weights: The sampling ratio for the datasets iterate_over_all: When iterate_over_all is True, the combined dataset iterates over all the datasets. Otherwise, it stops as soon as one raises a StopIteration. 
- """ + """ self._check_datasets(datasets) self._seed = seed diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index 997d2a83..df0ea012 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -35,14 +35,14 @@ def __init__( region_of_interest: Optional[List[Tuple[int, int]]] = None, storage_options: Optional[Dict] = {}, ) -> None: - """The ChunksConfig reads the index files associated a chunked dataset and enables to map an index to its - chunk. + """Reads the index files associated a chunked dataset and enables to map an index to its chunk. Arguments: cache_dir: The path to cache folder. serializers: The serializers used to serialize and deserialize the chunks. remote_dir: The path to a remote folder where the data are located. The scheme needs to be added to the path. + item_loader: The item loader used to load the data from the chunks. subsampled_files: List of subsampled chunk files loaded from `input_dir/index.json` file. region_of_interest: List of tuples of {start,end} of region of interest for each chunk. storage_options: Additional connection options for accessing storage services. diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 3f932fae..3f0bf584 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -80,11 +80,11 @@ def __init__( ): """The `CacheDataset` is a dataset wrapper to provide a beginner experience with the Cache. - Arguments: + Args: dataset: The dataset of the user cache_dir: The folder where the chunks are written to. chunk_bytes: The maximal number of bytes to write within a chunk. - chunk_sie: The maximal number of items to write to a chunk. + chunk_size: The maximal number of items to write to a chunk. compression: The compression algorithm to use to reduce the size of the chunk. """ diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 5c57cd69..df2228aa 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -59,7 +59,7 @@ def __init__( ) -> None: """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class. - Arguments: + Args: input_dir: Path to the folder where the input data is stored. item_loader: The logic to load an item from a chunk. shuffle: Whether to shuffle the data. diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 7c28595f..acc50d2b 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -70,11 +70,11 @@ def state_dict(self) -> Dict: @abstractmethod def generate_intervals(self) -> List[Interval]: - """Returns a list of intervals: [chunk_start, - region_of_interest_start, region_of_interest_end, chunk_end] + """Returns a list of intervals. - region_of_interest: indicates the indexes a chunk our StreamingDataset is allowed to read. + The structure is: [chunk_start, region_of_interest_start, region_of_interest_end, chunk_end] + region_of_interest: indicates the indexes a chunk our StreamingDataset is allowed to read. """ @abstractmethod @@ -164,7 +164,6 @@ def _load_encrypted_data( self, chunk_filepath: str, chunk_index: int, offset: int, encryption: Optional[Encryption] ) -> bytes: """Load and decrypt data from chunk based on the encryption configuration.""" - # Validate the provided encryption object against the expected configuration. 
self._validate_encryption(encryption) @@ -257,11 +256,10 @@ class TokensLoader(BaseItemLoader): def __init__(self, block_size: Optional[int] = None): """The Tokens Loader is an optimizer item loader for NLP. - Arguments: + Args: block_size: The context length to use during training. """ - super().__init__() self._block_size = block_size self._mmaps: Dict[int, np.memmap] = {} diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index f638df55..60cd2965 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -170,7 +170,7 @@ def __init__( ) -> None: """The BinaryReader enables to read chunked dataset in an efficient way. - Arguments: + Args: cache_dir: The path to cache folder. subsampled_files: List of subsampled chunk files loaded from `input_dir/index.json` file. region_of_interest: List of tuples of {start,end} of region of interest for each chunk. diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index 874497a6..1e388745 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -348,7 +348,6 @@ def _execute( command: Optional[str] = None, ) -> None: """Remotely execute the current operator.""" - if not _LIGHTNING_SDK_AVAILABLE: raise ModuleNotFoundError("The `lightning_sdk` is required.") diff --git a/src/litdata/streaming/sampler.py b/src/litdata/streaming/sampler.py index f25337d6..d135979b 100644 --- a/src/litdata/streaming/sampler.py +++ b/src/litdata/streaming/sampler.py @@ -45,7 +45,7 @@ def __init__( If the cache isn't filled, the batch sampler alternates with ordered indices for the writer to chunk the dataset If the cache is filled, it acts as normal BatchSampler. - Arguments: + Args: dataset_size: The size of the dataset. num_replicas: The number of processes involves in the distributed training. global_rank: The global_rank of the given process diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index 752b21d0..90e1085b 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -59,7 +59,8 @@ def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk class NoShuffle(Shuffle): """NoShuffle doesn't shuffle the items and ensure all the processes receive the same number of items if drop_last - is True.""" + is True. + """ @lru_cache(maxsize=10) def get_chunks_and_intervals_per_workers( @@ -83,13 +84,10 @@ class FullShuffle(Shuffle): """FullShuffle shuffles the chunks and associates them to the ranks. As the number of items in a chunk varies, it is possible for a rank to end up with more or less items. - To ensure the same fixed dataset length for all ranks while dropping as few items as possible, - we adopt the following strategy. We compute the maximum number of items per rank (M) and iterate through the chunks and ranks - until we have associated at least M items per rank. As a result, we lose at most (number of ranks) items. However, as some chunks are shared across ranks. This leads to diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py index 25dded88..10f1f11f 100644 --- a/src/litdata/streaming/writer.py +++ b/src/litdata/streaming/writer.py @@ -58,14 +58,16 @@ def __init__( ): """The BinaryWriter enables to chunk dataset into an efficient streaming format for cloud training. - Arguments: + Args: cache_dir: The path to where the chunks will be saved. chunk_bytes: The maximum number of bytes within a chunk. 
chunk_size: The maximum number of items within a chunk. compression: The compression algorithm to use. encryption: The encryption algorithm to use. + follow_tensor_dimension: Whether to follow the tensor dimension when serializing the data. serializers: Provide your own serializers. chunk_index: The index of the chunk to start from. + item_loader: The object responsible to generate the chunk intervals and load an item from a chunk. """ self._cache_dir = cache_dir @@ -155,7 +157,6 @@ def get_config(self) -> Dict[str, Any]: def serialize(self, items: Any) -> Tuple[bytes, Optional[int]]: """Serialize a dictionary into its binary format.""" - # Flatten the items provided by the users flattened, data_spec = tree_flatten(items) @@ -289,8 +290,8 @@ def __setitem__(self, index: int, items: Any) -> None: def add_item(self, index: int, items: Any) -> Optional[str]: """Given an index and items will serialize the items and store an Item object to the growing - `_serialized_items`.""" - + `_serialized_items`. + """ if index in self._serialized_items: raise ValueError(f"The provided index {index} already exists in the cache.") @@ -413,7 +414,8 @@ def done(self) -> List[str]: def merge(self, num_workers: int = 1, node_rank: Optional[int] = None) -> None: """Once all the workers have written their own index, the merge function is responsible to read and merge them - into a single index.""" + into a single index. + """ num_workers = num_workers or 1 # Only for non rank 0 @@ -443,7 +445,7 @@ def _merge_no_wait(self, node_rank: Optional[int] = None, existing_index: Option """Once all the workers have written their own index, the merge function is responsible to read and merge them into a single index. - Arguments: + Args: node_rank: The node rank of the index file existing_index: Existing index to be added to the newly created one. diff --git a/src/litdata/utilities/broadcast.py b/src/litdata/utilities/broadcast.py index 770aabb3..8d94c5c0 100644 --- a/src/litdata/utilities/broadcast.py +++ b/src/litdata/utilities/broadcast.py @@ -48,8 +48,7 @@ def _response(r: Any, *args: Any, **kwargs: Any) -> Any: class _HTTPClient: - """A wrapper class around the requests library which handles chores like logging, retries, and timeouts - automatically.""" + """A wrapper around the requests library which handles chores like logging, retries, and timeouts automatically.""" def __init__( self, diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index ca43bedd..0c482074 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -161,7 +161,7 @@ def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Di def generate_roi(chunks: List[Dict[str, Any]], item_loader: Optional[BaseItemLoader] = None) -> List[Tuple[int, int]]: - "Generates default region_of_interest for chunks." + """Generates default region_of_interest for chunks.""" roi = [] if isinstance(item_loader, TokensLoader): @@ -207,13 +207,14 @@ def load_index_file(input_dir: str) -> Dict[str, Any]: def adapt_mds_shards_to_chunks(data: Dict[str, Any]) -> Dict[str, Any]: """Adapt mds shard-based index data to chunk-based format for compatibility. - For more details about MDS, refer to the MosaicML Streaming documentation: https://github.com/mosaicml/streaming + For more details about MDS, refer to the MosaicML Streaming documentation: https://github.com/mosaicml/streaming. Args: data (Dict[str, Any]): The original index data containing shards. 
Returns: Dict[str, Any]: Adapted index data with chunks format. + """ chunks = [] shards = data["shards"] diff --git a/src/litdata/utilities/env.py b/src/litdata/utilities/env.py index b101f57f..74f72c46 100644 --- a/src/litdata/utilities/env.py +++ b/src/litdata/utilities/env.py @@ -36,7 +36,7 @@ def __init__(self, world_size: int, global_rank: int, num_nodes: int): def detect(cls) -> "_DistributedEnv": """Tries to automatically detect the distributed environment parameters. - Note: + .. note:: This detection may not work in processes spawned from the distributed processes (e.g. DataLoader workers) as the distributed framework won't be initialized there. It will default to 1 distributed process in this case. @@ -103,7 +103,7 @@ def __init__(self, world_size: int, rank: int): def detect(cls, get_worker_info_fn: Optional[Callable] = None) -> "_WorkerEnv": """Automatically detects the number of workers and the current rank. - Note: + .. note:: This only works reliably within a dataloader worker as otherwise the necessary information won't be present. In such a case it will default to 1 worker @@ -146,7 +146,7 @@ def from_args( """Generates the Environment class by already given arguments instead of detecting them. Args: - dist_world_size: The worldsize used for distributed training (=total number of distributed processes) + dist_world_size: The world-size used for distributed training (=total number of distributed processes) global_rank: The distributed global rank of the current process num_workers: The number of workers per distributed training process current_worker_rank: The rank of the current worker within the number of workers of @@ -162,7 +162,7 @@ def from_args( def num_shards(self) -> int: """Returns the total number of shards. - Note: + .. note:: This may not be accurate in a non-dataloader-worker process like the main training process as it doesn't necessarily know about the number of dataloader workers. @@ -175,7 +175,7 @@ def num_shards(self) -> int: def shard_rank(self) -> int: """Returns the rank of the current process wrt. the total number of shards. - Note: + .. note:: This may not be accurate in a non-dataloader-worker process like the main training process as it doesn't necessarily know about the number of dataloader workers. diff --git a/src/litdata/utilities/packing.py b/src/litdata/utilities/packing.py index 3b1c8480..0c11d9a7 100644 --- a/src/litdata/utilities/packing.py +++ b/src/litdata/utilities/packing.py @@ -17,8 +17,8 @@ def _pack_greedily(items: List[Any], weights: List[int], num_bins: int) -> Tuple[Dict[int, List[Any]], Dict[int, int]]: """Greedily pack items with given weights into bins such that the total weight of each bin is roughly equally - distributed among all bins.""" - + distributed among all bins. + """ if len(items) != len(weights): raise ValueError(f"Items and weights must have the same length, got {len(items)} and {len(weights)}.") if any(w <= 0 for w in weights): diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 9133ae05..f1861fe9 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -51,7 +51,8 @@ def _group_chunks_by_nodes( num_workers_per_process: int, ) -> List[List[int]]: """Takes a list representing chunks grouped by worker (global worker id across ranks and nodes) and returns a list - in which the chunks are grouped by node.""" + in which the chunks are grouped by node. 
+ """ chunk_indexes_per_nodes: Any = [[] for _ in range(num_nodes)] num_processes_per_node = world_size // num_nodes for worker_global_id, chunks in enumerate(chunks_per_workers): @@ -135,7 +136,6 @@ def _find_chunks_per_workers_on_which_to_skip_deletion( to only let the worker delete a chunk when that worker is the last to read from it. """ - # Shared chunks across all workers and ranks shared_chunks = _get_shared_chunks(workers_chunks) @@ -279,7 +279,8 @@ def _aggregate_shared_chunks_per_rank( def _map_node_worker_rank_to_chunk_indexes_to_not_delete(to_disable: Dict[int, List[int]]) -> Dict[int, List[int]]: """Takes a dictionary mapping a chunk index to a list of workers and inverts the map such that it returns a - dictionary mapping a worker to a list of chunk indexes (that should not be deleted by that worker).""" + dictionary mapping a worker to a list of chunk indexes (that should not be deleted by that worker). + """ map_node_worker_rank_to_chunk_indexes: Dict[int, List[int]] = {} for chunk_index, worker_ids in to_disable.items(): for worker_idx in worker_ids: diff --git a/src/litdata/utilities/train_test_split.py b/src/litdata/utilities/train_test_split.py index d31fb076..a44b8b5c 100644 --- a/src/litdata/utilities/train_test_split.py +++ b/src/litdata/utilities/train_test_split.py @@ -20,9 +20,10 @@ def train_test_split( These subsets can be used for training, testing, and validation purposes. Args: - streaming_dataset (StreamingDataset): An instance of StreamingDataset that needs to be split. - splits (List[float]): A list of floats representing the proportion of data to be allocated to each split + streaming_dataset: An instance of StreamingDataset that needs to be split. + splits: A list of floats representing the proportion of data to be allocated to each split (e.g., [0.8, 0.1, 0.1] for 80% training, 10% testing, and 10% validation). + seed: An integer used to seed the random number generator for reproducibility. Returns: List[StreamingDataset]: A list of StreamingDataset instances, where each element represents a split of the diff --git a/tests/conftest.py b/tests/conftest.py index 32a86645..a4b1f707 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,7 @@ @pytest.fixture(autouse=True) def teardown_process_group(): # noqa: PT004 - """Ensures that the distributed process group gets closed before the next test runs.""" + """Ensures distributed process group gets closed before the next test runs.""" yield if torch.distributed.is_available() and torch.distributed.is_initialized(): torch.distributed.destroy_process_group() @@ -84,7 +84,7 @@ def lightning_sdk_mock(monkeypatch): @pytest.fixture(autouse=True) def _thread_police(): - """Attempts to stop left-over threads to avoid test interactions. + """Attempts stopping left-over threads to avoid test interactions. Adapted from PyTorch Lightning. 
diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py index 543b6909..cc3ef9b2 100644 --- a/tests/processing/test_data_processor.py +++ b/tests/processing/test_data_processor.py @@ -460,8 +460,7 @@ def _broadcast_object(self, obj: Any) -> Any: condition=(not _PIL_AVAILABLE or sys.platform == "win32" or sys.platform == "linux"), reason="Requires: ['pil']" ) def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir, monkeypatch): - """This test ensures the data optimizer works in a fully distributed settings.""" - + """Ensures the data optimizer works in a fully distributed settings.""" seed_everything(42) monkeypatch.setattr(data_processor_module.os, "_exit", mock.MagicMock()) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 78cb0eaa..ef93021c 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -892,8 +892,7 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False): @pytest.mark.timeout(60) @pytest.mark.parametrize("shuffle", [True, False]) def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): - """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have - the same size.""" + """Tests resuming from a chunk past the first chunk, when subsequent chunks don't have the same size.""" s3_cache_dir = str(tmpdir / "s3cache") optimize_data_cache_dir = str(tmpdir / "optimize_data_cache") optimize_cache_dir = str(tmpdir / "optimize_cache")
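As a usage reference for the streaming classes whose docstrings are touched in this patch (`StreamingDataset`, `CombinedStreamingDataset`, `StreamingDataLoader`), here is a minimal sketch; the bucket paths, sampling weights, and batch size are illustrative assumptions:

from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset

# Two independently optimized datasets (hypothetical locations).
ds_a = StreamingDataset("s3://my-bucket/dataset-a", shuffle=True)
ds_b = StreamingDataset("s3://my-bucket/dataset-b", shuffle=True)

# Sample with a 70/30 ratio; iterate_over_all=False stops as soon as
# one of the datasets raises StopIteration.
combined = CombinedStreamingDataset([ds_a, ds_b], weights=[0.7, 0.3], iterate_over_all=False)

loader = StreamingDataLoader(combined, batch_size=32, num_workers=4)
for batch in loader:
    ...  # training step goes here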