diff --git a/pyproject.toml b/pyproject.toml index 63c61d93..f262c10a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,6 +117,7 @@ files = [ ] # This section is for folders with "-" as they are not valid python modules exclude = [ + "src/litdata/utilities/_pytree.py", ] install_types = "True" non_interactive = "True" diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index c8a79ca9..b8b6be3d 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -20,7 +20,6 @@ import signal import tempfile import traceback -import types from abc import abstractmethod from dataclasses import dataclass from multiprocessing import Process, Queue @@ -625,7 +624,7 @@ def _handle_data_chunk_recipe(self, index: int) -> None: try: current_item = self.items[index] if self.reader is None else self.reader.read(self.items[index]) item_data_or_generator = self.data_recipe.prepare_item(current_item) - if isinstance(item_data_or_generator, types.GeneratorType): + if self.data_recipe.is_generator: for item_data in item_data_or_generator: if item_data is not None: chunk_filepath = self.cache._add_item(self._index_counter, item_data) diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index 14387425..a07d7b65 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -130,26 +130,31 @@ def __init__( super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression) self._fn = fn self._inputs = inputs + self.is_generator = False + + self.check_fn() + + self.prepare_item = self._prepare_item_generator if self.is_generator else self._prepare_item # type: ignore + + def check_fn(self) -> None: + if ( + isinstance(self._fn, (partial, FunctionType)) + and inspect.isgeneratorfunction(self._fn) + or (callable(self._fn) and inspect.isgeneratorfunction(self._fn.__call__)) # type: ignore + ): + self.is_generator = 
True + + def _prepare_item(self, item_metadata: Any) -> Any: + return self._fn(item_metadata) + + def _prepare_item_generator(self, item_metadata: Any) -> Any: + yield from self._fn(item_metadata) # type: ignore def prepare_structure(self, input_dir: Optional[str]) -> Any: return self._inputs def prepare_item(self, item_metadata: Any) -> Any: - if isinstance(self._fn, partial): - yield from self._fn(item_metadata) - - elif isinstance(self._fn, FunctionType): - if inspect.isgeneratorfunction(self._fn): - yield from self._fn(item_metadata) - else: - yield self._fn(item_metadata) - elif callable(self._fn): - if inspect.isgeneratorfunction(self._fn.__call__): # type: ignore - yield from self._fn.__call__(item_metadata) # type: ignore - else: - yield self._fn.__call__(item_metadata) # type: ignore - else: - raise ValueError(f"The provided {self._fn} isn't supported.") + """This method is overridden dynamically.""" def map( diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py index 21809a62..2b7aa3cb 100644 --- a/src/litdata/streaming/writer.py +++ b/src/litdata/streaming/writer.py @@ -142,7 +142,9 @@ def serialize(self, items: Any) -> Tuple[bytes, Optional[int]]: # Flatten the items provided by the users flattened, data_spec = tree_flatten(items) - is_single_tensor = len(flattened) == 1 and isinstance(flattened[0], torch.Tensor) + is_single_tensor = ( + len(flattened) == 1 and isinstance(flattened[0], torch.Tensor) and len(flattened[0].shape) == 1 + ) # Collect the sizes and associated bytes for each item sizes: List[int] = [] diff --git a/status.json b/status.json deleted file mode 100644 index 72aa5495..00000000 --- a/status.json +++ /dev/null @@ -1 +0,0 @@ -{ "progress": "20.0%" } diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py index 2c167ccf..86cf7da8 100644 --- a/tests/processing/test_data_processor.py +++ b/tests/processing/test_data_processor.py @@ -362,6 +362,8 @@ def 
test_map_items_to_workers_sequentially(monkeypatch): class CustomDataChunkRecipe(DataChunkRecipe): + is_generator = False + def prepare_structure(self, input_dir: str) -> List[Any]: filepaths = [os.path.join(input_dir, f) for f in os.listdir(input_dir)] assert len(filepaths) == 30 @@ -553,6 +555,8 @@ def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir, class TextTokenizeRecipe(DataChunkRecipe): + is_generator = True + def prepare_structure(self, input_dir: str) -> List[Any]: return [os.path.join(input_dir, "dummy.txt")]