diff --git a/README.md b/README.md index 312f3f68..decea184 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@     -**Transform datasets at scale. +**Transform datasets at scale. Optimize data for fast AI model training.** @@ -45,20 +45,20 @@ Transform Optimize   -# Transform data at scale. Optimize for fast model training. +# Transform data at scale. Optimize for fast model training. LitData scales [data processing tasks](#transform-datasets) (data scraping, image resizing, distributed inference, embedding creation) on local or cloud machines. It also enables [optimizing datasets](#speed-up-model-training) to accelerate AI model training and work with large remote datasets without local loading.   # Quick start -First, install LitData: +First, install LitData: ```bash pip install litdata -``` +``` + +Choose your workflow: -Choose your workflow: - πŸš€ [Speed up model training](#speed-up-model-training) πŸš€ [Transform datasets](#transform-datasets) @@ -72,7 +72,7 @@ Install all the extras pip install 'litdata[extras]' ``` - +   @@ -81,25 +81,26 @@ pip install 'litdata[extras]' # Speed up model training Accelerate model training (20x faster) by optimizing datasets for streaming directly from cloud storage. Work with remote data without local downloads with features like loading data subsets, accessing individual samples, and resumable streaming. -**Step 1: Optimize the data** -This step will format the dataset for fast loading (binary, chunked, etc...) +**Step 1: Optimize the data** +This step will format the dataset for fast loading. The data will be written in a chunked binary format. ```python import numpy as np from PIL import Image import litdata as ld - + def random_images(index): fake_images = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)) fake_labels = np.random.randint(10) - # use any key:value pairs + # You can use any key:value pairs. Note that their types must not change between samples, and Python lists must + # always contain the same number of elements with the same types. data = {"index": index, "image": fake_images, "class": fake_labels} return data if __name__ == "__main__": - # the optimize function outputs data in an optimized format (chunked, binerized, etc...) + # The optimize function writes data in an optimized format. ld.optimize( fn=random_images, # the function applied to each input inputs=list(range(1000)), # the inputs to the function (here it's a list of numbers) @@ -107,16 +108,16 @@ if __name__ == "__main__": num_workers=4, # The number of workers on the same machine chunk_bytes="64MB" # size of each chunk ) -``` +``` **Step 2: Put the data on the cloud** -Upload the data to a [Lightning Studio](https://lightning.ai) (backed by S3) or your own S3 bucket: +Upload the data to a [Lightning Studio](https://lightning.ai) (backed by S3) or your own S3 bucket: ```bash aws s3 cp --recursive my_optimized_dataset s3://my-bucket/my_optimized_dataset -``` +``` -**Step 3: Stream the data during training** +**Step 3: Stream the data during training** Load the data by replacing the PyTorch DataSet and DataLoader with the StreamingDataset and StreamingDataloader @@ -143,10 +144,10 @@ for sample in dataloader:   ----- +---- -# Transform datasets -Accelerate data processing tasks (data scraping, image resizing, embedding creation, distributed inference) by parallelizing (map) the work across many machines at once. 
+# Transform datasets +Accelerate data processing tasks (data scraping, image resizing, embedding creation, distributed inference) by parallelizing (map) the work across many machines at once. Here's an example that resizes and crops a large image dataset: @@ -154,7 +155,7 @@ Here's an example that resizes and crops a large image dataset: from PIL import Image import litdata as ld -# use a local or S3 folder +# use a local or S3 folder input_dir = "my_large_images" # or "s3://my-bucket/my_large_images" output_dir = "my_resized_images" # or "s3://my-bucket/my_resized_images" @@ -164,10 +165,10 @@ inputs = [os.path.join(input_dir, f) for f in os.listdir(input_dir)] def resize_image(image_path, output_dir): output_image_path = os.path.join(output_dir, os.path.basename(image_path)) Image.open(image_path).resize((224, 224)).save(output_image_path) - + ld.map( fn=resize_image, - inputs=inputs, + inputs=inputs, output_dir="output_dir", ) ``` @@ -186,18 +187,18 @@ ld.map( # Key Features -## Features for optimizing and streaming datasets for model training +## Features for optimizing and streaming datasets for model training
βœ… Stream large cloud datasets   -Use data stored on the cloud without needing to download it all to your computer, saving time and space. +Use data stored on the cloud without needing to download it all to your computer, saving time and space. Imagine you're working on a project with a huge amount of data stored online. Instead of waiting hours to download it all, you can start working with the data almost immediately by streaming it. -Once you've optimized the dataset with LitData, stream it as follows: +Once you've optimized the dataset with LitData, stream it as follows: ```python from litdata import StreamingDataset, StreamingDataLoader @@ -224,7 +225,7 @@ storage_options = { dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options) ``` -
+
✅ Streams on multi-GPU, multi-node

@@ -233,13 +234,13 @@ dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_opt

Data optimized and loaded with Lightning automatically streams efficiently in distributed training across GPUs or multi-node.

-The `StreamingDataset` and `StreamingDataLoader` automatically make sure each rank receives the same quantity of varied batches of data, so it works out of the box with your favorite frameworks ([PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), [Lightning Fabric](https://lightning.ai/docs/fabric/stable/), or [PyTorch](https://pytorch.org/docs/stable/index.html)) to do distributed training.
+The `StreamingDataset` and `StreamingDataLoader` automatically make sure each rank receives the same quantity of varied batches of data, so it works out of the box with your favorite frameworks ([PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), [Lightning Fabric](https://lightning.ai/docs/fabric/stable/), or [PyTorch](https://pytorch.org/docs/stable/index.html)) to do distributed training.

Here you can see an illustration showing how the Streaming Dataset works with multi node / multi gpu under the hood.

![An illustration showing how the Streaming Dataset works with multi node.](https://pl-flash-data.s3.amazonaws.com/streaming_dataset.gif)

-
+
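For reference, here is a minimal sketch (not taken from this diff) of how the same streaming code is typically launched for distributed training, for example with `torchrun --nproc_per_node=4 train.py`; the bucket path is a placeholder:

```python
from litdata import StreamingDataset, StreamingDataLoader

# Placeholder path: point this at your own optimized dataset.
dataset = StreamingDataset("s3://my-bucket/my_optimized_dataset", shuffle=True, drop_last=True)
dataloader = StreamingDataLoader(dataset, batch_size=64, num_workers=4)

for batch in dataloader:
    ...  # each rank iterates over its own, non-overlapping share of the batches
```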
✅ Stream from multiple cloud providers

@@ -300,13 +301,13 @@ if os.path.isfile("dataloader_state.pt"):

# Iterate over the data
for batch_idx, batch in enumerate(dataloader):
-
+
    # Store the state every 1000 batches
    if batch_idx % 1000 == 0:
        torch.save(dataloader.state_dict(), "dataloader_state.pt")
```

-
+
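As a rough illustration of the multi-provider streaming mentioned above, the same `StreamingDataset` API can be pointed at different URI schemes; scheme support and credential handling depend on your LitData version, and the paths below are placeholders, not part of this diff:

```python
from litdata import StreamingDataset

# Placeholder paths; credentials are read from the environment or storage_options.
s3_dataset = StreamingDataset("s3://my-bucket/my-data")    # AWS S3
gcs_dataset = StreamingDataset("gs://my-bucket/my-data")   # Google Cloud Storage, if supported by your version
```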
@@ -315,7 +316,7 @@ for batch_idx, batch in enumerate(dataloader): Mix and match different sets of data to experiment and create better models. -Combine datasets with `CombinedStreamingDataset`. As an example, this mixture of [Slimpajama](https://huggingface.co/datasets/cerebras/SlimPajama-627B) & [StarCoder](https://huggingface.co/datasets/bigcode/starcoderdata) was used in the [TinyLLAMA](https://github.com/jzhang38/TinyLlama) project to pretrain a 1.1B Llama model on 3 trillion tokens. +Combine datasets with `CombinedStreamingDataset`. As an example, this mixture of [Slimpajama](https://huggingface.co/datasets/cerebras/SlimPajama-627B) & [StarCoder](https://huggingface.co/datasets/bigcode/starcoderdata) was used in the [TinyLLAMA](https://github.com/jzhang38/TinyLlama) project to pretrain a 1.1B Llama model on 3 trillion tokens. ```python from litdata import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader, TokensLoader @@ -325,13 +326,13 @@ import os train_datasets = [ StreamingDataset( input_dir="s3://tinyllama-template/slimpajama/train/", - item_loader=TokensLoader(block_size=2048 + 1), # Optimized loader for tokens used by LLMs + item_loader=TokensLoader(block_size=2048 + 1), # Optimized loader for tokens used by LLMs shuffle=True, drop_last=True, ), StreamingDataset( input_dir="s3://tinyllama-template/starcoder/", - item_loader=TokensLoader(block_size=2048 + 1), # Optimized loader for tokens used by LLMs + item_loader=TokensLoader(block_size=2048 + 1), # Optimized loader for tokens used by LLMs shuffle=True, drop_last=True, ), @@ -339,7 +340,7 @@ train_datasets = [ # Mix SlimPajama data and Starcoder data with these proportions: weights = (0.693584, 0.306416) -combined_dataset = CombinedStreamingDataset(datasets=train_datasets, seed=42, weights=weights) +combined_dataset = CombinedStreamingDataset(datasets=train_datasets, seed=42, weights=weights, iterate_over_all=False) train_dataloader = StreamingDataLoader(combined_dataset, batch_size=8, pin_memory=True, num_workers=os.cpu_count()) @@ -347,7 +348,7 @@ train_dataloader = StreamingDataLoader(combined_dataset, batch_size=8, pin_memor for batch in tqdm(train_dataloader): pass ``` -
+
✅ Split datasets for train, val, test

@@ -376,13 +377,13 @@
print(test_dataset)
# out: 50,000
```

-
+
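The split printed above is typically produced with LitData's `train_test_split` helper; a minimal sketch with a placeholder path and ratios, assuming the helper is available in your version:

```python
from litdata import StreamingDataset, train_test_split

dataset = StreamingDataset("s3://my-bucket/my-data")  # placeholder path
train_dataset, val_dataset, test_dataset = train_test_split(dataset, splits=[0.8, 0.1, 0.1])
```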
✅ Load a subset of the remote dataset

-Work on a smaller, manageable portion of your data to save time and resources.

```python
@@ -394,7 +395,7 @@ print(len(dataset)) # display the length of your data
# out: 1000
```

-
+
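A minimal sketch (not from this diff) of loading only a fraction of a remote dataset, assuming your LitData version exposes the `subsample` argument; the path and fraction are placeholders:

```python
from litdata import StreamingDataset

dataset = StreamingDataset("s3://my-bucket/my-data", subsample=0.01)  # stream roughly 1% of the samples
print(len(dataset))
```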
✅ Easily modify optimized cloud datasets

@@ -435,13 +436,13 @@ if __name__ == "__main__":

The `overwrite` mode will delete the existing data and start from fresh.

-
+
✅ Access samples without full data download

-Look at specific parts of a large dataset without downloading the whole thing or loading it on a local machine.

```python
from litdata import StreamingDataset
@@ -453,7 +454,7 @@ print(len(dataset)) # display the length of your data
print(dataset[42]) # show the 42th element of the dataset
```

-
+
✅ Use any data transforms

@@ -481,13 +482,13 @@ for batch in dataloader:
# Out: (4, 3, 224, 224)
```

-
+
✅ Profile data loading speed

-Measure and optimize how fast your data is being loaded, improving efficiency.

The `StreamingDataLoader` supports profiling of your data loading process. Simply use the `profile_batches` argument to specify the number of batches you want to profile:

@@ -499,7 +500,7 @@ StreamingDataLoader(..., profile_batches=5)
```
This generates a Chrome trace called `result.json`. Then, visualize this trace by opening Chrome browser at the `chrome://tracing` URL and load the trace inside.

-
+
✅ Reduce memory use for large files
@@ -507,7 +508,7 @@ This generates a Chrome trace called `result.json`. Then, visualize this trace b

Handle large data files efficiently without using too much of your computer's memory.

-When processing large files like compressed [parquet files](https://en.wikipedia.org/wiki/Apache_Parquet), use the Python yield keyword to process and store one item at the time, reducing the memory footprint of the entire program.
+When processing large files like compressed [parquet files](https://en.wikipedia.org/wiki/Apache_Parquet), use the Python yield keyword to process and store one item at the time, reducing the memory footprint of the entire program.

```python
from pathlib import Path
@@ -537,13 +538,13 @@ outputs = optimize(
)
```

-
+
βœ… Limit local cache space   -Limit the amount of disk space used by temporary files, preventing storage issues. +Limit the amount of disk space used by temporary files, preventing storage issues. Adapt the local caching limit of the `StreamingDataset`. This is useful to make sure the downloaded data chunks are deleted when used and the disk usage stays low. @@ -553,7 +554,7 @@ from litdata import StreamingDataset dataset = StreamingDataset(..., max_cache_size="10GB") ``` -
+
✅ Change cache directory path
@@ -578,7 +579,7 @@ dataset = StreamingDataset(input_dir=Dir(path=cache_dir, url=data_dir))

Optimize data handling for computers on a local network to improve performance for on-site setups.
-
+
On-prem compute nodes can mount and use a network drive. A network drive is a shared storage device on a local area network. In order to reduce their network overload, the `StreamingDataset` supports `caching` the data chunks.

```python
@@ -629,11 +630,11 @@ dataset = StreamingDataset(output_dir)
print(dataset[:])
```

-
+   -## Features for transforming datasets +## Features for transforming datasets
✅ Parallelize data transformations (map)

@@ -653,28 +654,28 @@ from PIL import Image
input_dir = "my_large_images"
inputs = [os.path.join(input_dir, f) for f in os.listdir(input_dir)]

-# The resize image takes one of the input (image_path) and the output directory.
+# The resize image takes one of the input (image_path) and the output directory.
# Files written to output_dir are persisted.
def resize_image(image_path, output_dir):
    output_image_path = os.path.join(output_dir, os.path.basename(image_path))
    Image.open(image_path).resize((224, 224)).save(output_image_path)
-
+
map(
    fn=resize_image,
-    inputs=inputs,
+    inputs=inputs,
    output_dir="s3://my-bucket/my_resized_images",
)
```

-
+
βœ… Support S3-Compatible cloud object storage   -Use different cloud storage services, offering data storage flexibility and cost-saving options. +Use different cloud storage services, offering data storage flexibility and cost-saving options. -Integrate S3-compatible object storage servers like [MinIO](https://min.io/) with litdata, ideal for on-premises infrastructure setups. Configure the endpoint and credentials using environment variables or configuration files. +Integrate S3-compatible object storage servers like [MinIO](https://min.io/) with litdata, ideal for on-premises infrastructure setups. Configure the endpoint and credentials using environment variables or configuration files. Set up the environment variables to connect to MinIO: @@ -701,7 +702,7 @@ EOL ``` Explore an example setup of litdata with MinIO in the [LitData with MinIO](https://github.com/bhimrazy/litdata-with-minio) repository for practical implementation details. -
+
βœ… Supports encryption and decryption of data at chunk/sample level @@ -777,13 +778,13 @@ With this setup, you can ensure that your data remains secure while maintaining ---- # Benchmarks -In this section we show benchmarks for speed to optimize a dataset and the resulting streaming speed ([Reproduce the benchmark](https://lightning.ai/lightning-ai/studios/benchmark-cloud-data-loading-libraries)). +In this section we show benchmarks for speed to optimize a dataset and the resulting streaming speed ([Reproduce the benchmark](https://lightning.ai/lightning-ai/studios/benchmark-cloud-data-loading-libraries)). ## Streaming speed -Data optimized and streamed with LitData achieves a 20x speed up over non optimized data and 2x speed up over other streaming solutions. +Data optimized and streamed with LitData achieves a 20x speed up over non optimized data and 2x speed up over other streaming solutions. -Speed to stream Imagenet 1.2M from AWS S3: +Speed to stream Imagenet 1.2M from AWS S3: | Framework | Images / sec 1st Epoch (float32) | Images / sec 2nd Epoch (float32) | Images / sec 1st Epoch (torch16) | Images / sec 2nd Epoch (torch16) | |---|---|---|---|---| @@ -795,17 +796,17 @@ Speed to stream Imagenet 1.2M from AWS S3: Benchmark details   -- [Imagenet-1.2M dataset](https://www.image-net.org/) contains `1,281,167 images`. -- To align with other benchmarks, we measured the streaming speed (`images per second`) loaded from [AWS S3](https://aws.amazon.com/s3/) for several frameworks. +- [Imagenet-1.2M dataset](https://www.image-net.org/) contains `1,281,167 images`. +- To align with other benchmarks, we measured the streaming speed (`images per second`) loaded from [AWS S3](https://aws.amazon.com/s3/) for several frameworks. -
+ -  +  -## Time to optimize data -LitData optimizes the Imagenet dataset for fast training 3-5x faster than other frameworks: +## Time to optimize data +LitData optimizes the Imagenet dataset for fast training 3-5x faster than other frameworks: -Time to optimize 1.2 million ImageNet images (Faster is better): +Time to optimize 1.2 million ImageNet images (Faster is better): | Framework |Train Conversion Time | Val Conversion Time | Dataset Size | # Files | |---|---|---|---|---| | PL Data | **10:05 min** | **00:30 min** | **143.1 GB** | 2.339 | @@ -816,16 +817,16 @@ Time to optimize 1.2 million ImageNet images (Faster is better): ---- -# Parallelize transforms and data optimization on cloud machines +# Parallelize transforms and data optimization on cloud machines
Lightning
-
+ + +## Parallelize data transforms -## Parallelize data transforms +Transformations with LitData are linearly parallelizable across machines. -Transformations with LitData are linearly parallelizable across machines. - -For example, let's say that it takes 56 hours to embed a dataset on a single A10G machine. With LitData, +For example, let's say that it takes 56 hours to embed a dataset on a single A10G machine. With LitData, this can be speed up by adding more machines in parallel | Number of machines | Hours | @@ -836,7 +837,7 @@ this can be speed up by adding more machines in parallel | ... | ... | | 64 | 0.875 | -To scale the number of machines, run the processing script on [Lightning Studios](https://lightning.ai/): +To scale the number of machines, run the processing script on [Lightning Studios](https://lightning.ai/): ```python from litdata import map, Machine @@ -848,8 +849,8 @@ map( ) ``` -## Parallelize data optimization -To scale the number of machines for data optimization, use [Lightning Studios](https://lightning.ai/): +## Parallelize data optimization +To scale the number of machines for data optimization, use [Lightning Studios](https://lightning.ai/): ```python from litdata import optimize, Machine @@ -863,26 +864,26 @@ optimize(   -Example: [Process the LAION 400 million image dataset in 2 hours on 32 machines, each with 32 CPUs](https://lightning.ai/lightning-ai/studios/use-or-explore-laion-400million-dataset). +Example: [Process the LAION 400 million image dataset in 2 hours on 32 machines, each with 32 CPUs](https://lightning.ai/lightning-ai/studios/use-or-explore-laion-400million-dataset).   ---- -# Start from a template -Below are templates for real-world applications of LitData at scale. +# Start from a template +Below are templates for real-world applications of LitData at scale. 
-## Templates: Transform datasets +## Templates: Transform datasets -| Studio | Data type | Time (minutes) | Machines | Dataset | +| Studio | Data type | Time (minutes) | Machines | Dataset | | ------------------------------------ | ----------------- | ----------------- | -------------- | -------------- | | [Download LAION-400MILLION dataset](https://lightning.ai/lightning-ai/studios/use-or-explore-laion-400million-dataset) | Image & Text | 120 | 32 |[LAION-400M](https://laion.ai/blog/laion-400-open-dataset/) | | [Tokenize 2M Swedish Wikipedia Articles](https://lightning.ai/lightning-ai/studios/tokenize-2m-swedish-wikipedia-articles) | Text | 7 | 4 | [Swedish Wikipedia](https://huggingface.co/datasets/wikipedia) | | [Embed English Wikipedia under 5 dollars](https://lightning.ai/lightning-ai/studios/embed-english-wikipedia-under-5-dollars) | Text | 15 | 3 | [English Wikipedia](https://huggingface.co/datasets/wikipedia) | -## Templates: Optimize + stream data +## Templates: Optimize + stream data -| Studio | Data type | Time (minutes) | Machines | Dataset | +| Studio | Data type | Time (minutes) | Machines | Dataset | | -------------------------------- | ----------------- | ----------------- | -------------- | -------------- | | [Benchmark cloud data-loading libraries](https://lightning.ai/lightning-ai/studios/benchmark-cloud-data-loading-libraries) | Image & Label | 10 | 1 | [Imagenet 1M](https://paperswithcode.com/sota/image-classification-on-imagenet?tag_filter=171) | | [Optimize GeoSpatial data for model training](https://lightning.ai/lightning-ai/studios/convert-spatial-data-to-lightning-streaming) | Image & Mask | 120 | 32 | [Chesapeake Roads Spatial Context](https://github.com/isaaccorley/chesapeakersc) | diff --git a/src/litdata/__about__.py b/src/litdata/__about__.py index 67950e46..8270919b 100644 --- a/src/litdata/__about__.py +++ b/src/litdata/__about__.py @@ -14,7 +14,7 @@ import time -__version__ = "0.2.17" +__version__ = "0.2.18" __author__ = "Lightning AI et al." 
__author_email__ = "pytorch@lightning.ai" __license__ = "Apache-2.0" diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index b5bb0fad..2d876206 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -508,7 +508,8 @@ def _set_environ_variables(self) -> None: def _create_cache(self) -> None: self.cache_data_dir = _get_cache_data_dir() - os.makedirs(self.cache_data_dir, exist_ok=True) + if not os.path.exists(self.cache_data_dir): # redundant but, otherwise it fails in CI on macOS + os.makedirs(self.cache_data_dir, exist_ok=True) self.cache_chunks_dir = _get_cache_dir() @@ -960,7 +961,9 @@ def run(self, data_recipe: DataRecipe) -> None: torch.manual_seed(self.random_seed) # Call the setup method of the user - user_items: List[Any] = data_recipe.prepare_structure(self.input_dir.path if self.input_dir else None) + user_items: Union[List[Any], StreamingDataLoader] = data_recipe.prepare_structure( + self.input_dir.path if self.input_dir else None + ) if not isinstance(user_items, (list, StreamingDataLoader)): raise ValueError("The `prepare_structure` should return a list of item metadata.") @@ -970,6 +973,8 @@ def run(self, data_recipe: DataRecipe) -> None: if self.reader: user_items = self.reader.remap_items(user_items, self.num_workers) + assert isinstance(user_items, list) + if self.weights is not None: if len(self.weights) != len(user_items): raise ValueError("The provided weights length should match the inputs' length.") diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index 86c2f3f4..51ca8159 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -98,7 +98,7 @@ def _get_default_num_workers() -> int: class LambdaDataTransformRecipe(DataTransformRecipe): - def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]): + def __init__(self, fn: Callable[[str, Any], None], inputs: Union[Sequence[Any], StreamingDataLoader]): super().__init__() self._fn = fn self._inputs = inputs @@ -144,7 +144,7 @@ class LambdaDataChunkRecipe(DataChunkRecipe): def __init__( self, fn: Callable[[Any], None], - inputs: Sequence[Any], + inputs: Union[Sequence[Any], StreamingDataLoader], chunk_size: Optional[int], chunk_bytes: Optional[Union[int, str]], compression: Optional[str], @@ -184,7 +184,7 @@ def prepare_item(self, item_metadata: Any) -> Any: def map( fn: Callable[[str, Any], None], - inputs: Sequence[Any], + inputs: Union[Sequence[Any], StreamingDataLoader], output_dir: Union[str, Dir], input_dir: Optional[str] = None, weights: Optional[List[int]] = None, @@ -199,12 +199,11 @@ def map( reader: Optional[BaseReader] = None, batch_size: Optional[int] = None, ) -> None: - """This function map a callbable over a collection of files possibly in a distributed way. + """This function maps a callable over a collection of inputs, possibly in a distributed way. Arguments: fn: A function to be executed over each input element - inputs: A sequence of input to be processed by the `fn` function. - Each input should contain at least a valid filepath. + inputs: A sequence of input to be processed by the `fn` function, or a streaming dataloader. output_dir: The folder where the processed data should be stored. input_dir: Provide the path where your files are stored. If the files are on a remote storage, they will be downloaded in the background while processed. 
@@ -228,7 +227,9 @@ def map( raise ValueError("When providing a streaming dataloader, weights isn't supported.") if not isinstance(inputs, (Sequence, StreamingDataLoader)): - raise ValueError(f"The provided inputs should be non empty sequence or a streaming dataloader. Found {inputs}.") + raise ValueError( + f"The provided inputs should be a non-empty sequence or a streaming dataloader. Found {inputs}." + ) if len(inputs) == 0: raise ValueError(f"The provided inputs should be non empty. Found {inputs}.") @@ -291,7 +292,7 @@ def map( def optimize( fn: Callable[[Any], Any], - inputs: Sequence[Any], + inputs: Union[Sequence[Any], StreamingDataLoader], output_dir: str, input_dir: Optional[str] = None, weights: Optional[List[int]] = None, @@ -311,12 +312,13 @@ def optimize( mode: Optional[Literal["append", "overwrite"]] = None, use_checkpoint: bool = False, ) -> None: - """This function converts a dataset into chunks possibly in a distributed way. + """This function converts a dataset into chunks, possibly in a distributed way. Arguments: - fn: A function to be executed over each input element - inputs: A sequence of input to be processed by the `fn` function. - Each input should contain at least a valid filepath. + fn: A function to be executed over each input element. The function should return the data sample that + corresponds to the input. Every invocation of the function should return a similar hierarchy of objects, + where the object types and list sizes don't change. + inputs: A sequence of input to be processed by the `fn` function, or a streaming dataloader. output_dir: The folder where the processed data should be stored. input_dir: Provide the path where your files are stored. If the files are on a remote storage, they will be downloaded in the background while processed. @@ -350,7 +352,9 @@ def optimize( raise ValueError("When providing a streaming dataloader, weights isn't supported.") if not isinstance(inputs, (Sequence, StreamingDataLoader)): - raise ValueError(f"The provided inputs should be non empty sequence or a streaming dataloader. Found {inputs}.") + raise ValueError( + f"The provided inputs should be a non-empty sequence or a streaming dataloader. Found {inputs}." + ) if len(inputs) == 0: raise ValueError(f"The provided inputs should be non empty. 
Found {inputs}.") diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 7b64f863..50ef214f 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -56,7 +56,7 @@ def seed_everything(random_seed): pytest.param("zstd", marks=pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")), ], ) -@pytest.mark.timeout(15) +@pytest.mark.timeout(30) def test_streaming_dataset(tmpdir, monkeypatch, compression): seed_everything(42) diff --git a/tests/streaming/test_serializer.py b/tests/streaming/test_serializer.py index 00cc75f4..12aa52b9 100644 --- a/tests/streaming/test_serializer.py +++ b/tests/streaming/test_serializer.py @@ -15,7 +15,6 @@ import os import random import sys -from time import time from unittest import mock import numpy as np @@ -33,7 +32,6 @@ NoHeaderNumpySerializer, NoHeaderTensorSerializer, NumpySerializer, - PickleSerializer, PILSerializer, TensorSerializer, VideoSerializer, @@ -140,10 +138,7 @@ def test_tensor_serializer(): seed_everything(42) serializer_tensor = TensorSerializer() - serializer_pickle = PickleSerializer() - ratio_times = [] - ratio_bytes = [] shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)] for dtype in _TORCH_DTYPES_MAPPING.values(): for shape in shapes: @@ -152,30 +147,12 @@ def test_tensor_serializer(): continue tensor = torch.ones(shape, dtype=dtype) - t0 = time() data, _ = serializer_tensor.serialize(tensor) deserialized_tensor = serializer_tensor.deserialize(data) - tensor_time = time() - t0 - tensor_bytes = len(data) assert deserialized_tensor.dtype == dtype assert torch.equal(tensor, deserialized_tensor) - t1 = time() - data, _ = serializer_pickle.serialize(tensor) - deserialized_tensor = serializer_pickle.deserialize(data) - pickle_time = time() - t1 - pickle_bytes = len(data) - - assert deserialized_tensor.dtype == dtype - assert torch.equal(tensor, deserialized_tensor) - - ratio_times.append(pickle_time / tensor_time) - ratio_bytes.append(pickle_bytes / tensor_bytes) - - assert np.mean(ratio_times) > 1.6 - assert np.mean(ratio_bytes) > 2 - @pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows") def test_numpy_serializer():