From f0b755a3fc7eb1514ceafab82c580d08de4003e4 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Mon, 26 Feb 2024 11:53:43 +0000
Subject: [PATCH] Hot fix: Fix path resolution (#29)

---
Reviewer notes (commentary only; ignored when the patch is applied):

* The nested `is_path` closure in `_collect_paths` is replaced by two
  module-level helpers, `_is_path(input_dir, element)` and
  `_to_path(element)`, which the test module now imports directly.
* `_to_path` returns the element unchanged when `_IS_IN_STUDIO` is set and
  the element starts with "/teamspace"; otherwise it returns
  `str(Path(element).resolve())`.
* `_is_path` rejects non-strings. When `_IS_IN_STUDIO` is set and an input
  dir is given, it accepts an element that starts with the input dir, either
  as written or after `Path(element).absolute()`. In all other cases it
  falls back to `os.path.exists`.
* The `_collected_items` counter is removed (its initialisation and its
  increment are the only uses visible in this patch).
* NOTE(review): this copy of the patch was rebuilt from a whitespace-flattened
  version. Line breaks, blank lines and indentation were inferred from the
  hunk header counts and Python syntax; confirm against the upstream commit
  before applying.

 litdata/processing/data_processor.py    | 36 +++++++++++++++----------
 tests/processing/test_data_processor.py | 23 ++++++++++++++++
 2 files changed, 45 insertions(+), 14 deletions(-)

diff --git a/litdata/processing/data_processor.py b/litdata/processing/data_processor.py
index e752b0a5..528720e6 100644
--- a/litdata/processing/data_processor.py
+++ b/litdata/processing/data_processor.py
@@ -339,6 +339,25 @@ def _get_item_filesizes(items: List[Any], base_path: str = "") -> List[int]:
     return item_sizes
 
 
+def _to_path(element: str) -> str:
+    return element if _IS_IN_STUDIO and element.startswith("/teamspace") else str(Path(element).resolve())
+
+
+def _is_path(input_dir: Optional[str], element: Any) -> bool:
+    if not isinstance(element, str):
+        return False
+
+    if _IS_IN_STUDIO and input_dir is not None:
+        if element.startswith(input_dir):
+            return True
+
+        element = str(Path(element).absolute())
+        if element.startswith(input_dir):
+            return True
+
+    return os.path.exists(element)
+
+
 class BaseWorker:
     def __init__(
         self,
@@ -381,7 +400,6 @@ def __init__(
         self.remove_queue: Queue = Queue()
         self.progress_queue: Queue = progress_queue
         self.error_queue: Queue = error_queue
-        self._collected_items = 0
         self._counter = 0
         self._last_time = time()
         self._index_counter = 0
@@ -504,22 +522,13 @@ def _collect_paths(self) -> None:
         for item in self.items:
             flattened_item, spec = tree_flatten(item)
 
-            def is_path(element: Any) -> bool:
-                if not isinstance(element, str):
-                    return False
-
-                element: str = str(Path(element).resolve())
-                if _IS_IN_STUDIO and self.input_dir.path is not None:
-                    if self.input_dir.path.startswith("/teamspace/studios/this_studio"):
-                        return os.path.exists(element)
-                    return element.startswith(self.input_dir.path)
-                return os.path.exists(element)
-
             # For speed reasons, we assume starting with `self.input_dir` is enough to be a real file.
             # Other alternative would be too slow.
             # TODO: Try using dictionary for higher accurary.
             indexed_paths = {
-                index: str(Path(element).resolve()) for index, element in enumerate(flattened_item) if is_path(element)
+                index: _to_path(element)
+                for index, element in enumerate(flattened_item)
+                if _is_path(self.input_dir.path, element)
             }
 
             if len(indexed_paths) == 0:
@@ -537,7 +546,6 @@ def is_path(element: Any) -> bool:
             self.paths.append(paths)
 
             items.append(tree_unflatten(flattened_item, spec))
-            self._collected_items += 1
 
         self.items = items
 
diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py
index 911997cd..0d715e3d 100644
--- a/tests/processing/test_data_processor.py
+++ b/tests/processing/test_data_processor.py
@@ -21,9 +21,11 @@
     DataTransformRecipe,
     _download_data_target,
     _get_item_filesizes,
+    _is_path,
     _map_items_to_workers_sequentially,
     _map_items_to_workers_weighted,
     _remove_target,
+    _to_path,
     _upload_fn,
     _wait_for_disk_usage_higher_than_threshold,
     _wait_for_file_to_exist,
@@ -1136,3 +1138,24 @@ def test_load_torch_audio_from_wav_file(tmpdir, compression):
     tensor = torchaudio.load(sample)
     assert tensor[0].shape == torch.Size([1, 16000])
     assert tensor[1] == 16000
+
+
+def test_is_path_valid_in_studio(monkeypatch, tmpdir):
+    filepath = os.path.join(tmpdir, "a.png")
+    with open(filepath, "w") as f:
+        f.write("Hello World")
+
+    monkeypatch.setattr(data_processor_module, "_IS_IN_STUDIO", True)
+
+    assert _is_path("/teamspace/studios/this_studio", "/teamspace/studios/this_studio/a.png")
+    assert _is_path("/teamspace/studios/this_studio", filepath)
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="skip windows")
+def test_to_path(tmpdir):
+    filepath = os.path.join(tmpdir, "a.png")
+    with open(filepath, "w") as f:
+        f.write("Hello World")
+
+    assert _to_path("/teamspace/studios/this_studio/a.png") == "/teamspace/studios/this_studio/a.png"
+    assert _to_path(filepath) == filepath