From 0a97def7eaaa8c38072da62e470cde8dfb7d3932 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 4 Dec 2024 09:36:51 +0000 Subject: [PATCH] Add example on how to filter illegal data (#432) --- README.md | 62 ++++++++++++++++++++++++++++++++++++++++ src/litdata/__about__.py | 2 +- 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index acaedcab..0bbd5090 100644 --- a/README.md +++ b/README.md @@ -399,6 +399,68 @@ for batch in tqdm(train_dataloader): +
+ ✅ Filter illegal data +  + +Sometimes, you have bad data that you don't want to include in the optimized dataset. With LitData, yield only the good data samples you want to include. + + +```python +from litdata import optimize, StreamingDataset + +def should_keep(index) -> bool: + # Replace with your own logic + return index % 2 == 0 + + +def fn(data): + if should_keep(data): + yield data + +if __name__ == "__main__": + optimize( + fn=fn, + inputs=list(range(1000)), + output_dir="only_even_index_optimized", + chunk_bytes="64MB", + num_workers=1 + ) + + dataset = StreamingDataset("only_even_index_optimized") + data = list(dataset) + print(data) + # [0, 2, 4, 6, 8, 10, ..., 992, 994, 996, 998] +``` + +You can even use try/except. + +```python +from litdata import optimize, StreamingDataset + +def fn(data): + try: + yield 1 / data + except: + pass + +if __name__ == "__main__": + optimize( + fn=fn, + inputs=[0, 0, 0, 1, 2, 4, 0], + output_dir="only_defined_ratio_optimized", + chunk_bytes="64MB", + num_workers=1 + ) + + dataset = StreamingDataset("only_defined_ratio_optimized") + data = list(dataset) + # The 0 are filtered out as they raise a division by zero + print(data) + # [1.0, 0.5, 0.25] +``` +
+
✅ Combine datasets   diff --git a/src/litdata/__about__.py b/src/litdata/__about__.py index 4b3982a4..82436a84 100644 --- a/src/litdata/__about__.py +++ b/src/litdata/__about__.py @@ -14,7 +14,7 @@ import time -__version__ = "0.2.33" +__version__ = "0.2.34" __author__ = "Lightning AI et al." __author_email__ = "pytorch@lightning.ai" __license__ = "Apache-2.0"