Skip to content

Commit

Permalink
Hot fix: Fix path resolution (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton authored Feb 26, 2024
1 parent 51b7fee commit f0b755a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 14 deletions.
36 changes: 22 additions & 14 deletions litdata/processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,25 @@ def _get_item_filesizes(items: List[Any], base_path: str = "") -> List[int]:
return item_sizes


def _to_path(element: str) -> str:
    """Return the path string to use for ``element``.

    When ``_IS_IN_STUDIO`` is truthy and ``element`` starts with ``/teamspace``, the string is
    returned untouched. Every other input is passed through ``Path.resolve()`` and returned as
    a string.
    """
    keep_as_is = _IS_IN_STUDIO and element.startswith("/teamspace")
    if keep_as_is:
        return element
    return str(Path(element).resolve())


def _is_path(input_dir: Optional[str], element: Any) -> bool:
if not isinstance(element, str):
return False

if _IS_IN_STUDIO and input_dir is not None:
if element.startswith(input_dir):
return True

element = str(Path(element).absolute())
if element.startswith(input_dir):
return True

return os.path.exists(element)


class BaseWorker:
def __init__(
self,
Expand Down Expand Up @@ -381,7 +400,6 @@ def __init__(
self.remove_queue: Queue = Queue()
self.progress_queue: Queue = progress_queue
self.error_queue: Queue = error_queue
self._collected_items = 0
self._counter = 0
self._last_time = time()
self._index_counter = 0
Expand Down Expand Up @@ -504,22 +522,13 @@ def _collect_paths(self) -> None:
for item in self.items:
flattened_item, spec = tree_flatten(item)

def is_path(element: Any) -> bool:
if not isinstance(element, str):
return False

element: str = str(Path(element).resolve())
if _IS_IN_STUDIO and self.input_dir.path is not None:
if self.input_dir.path.startswith("/teamspace/studios/this_studio"):
return os.path.exists(element)
return element.startswith(self.input_dir.path)
return os.path.exists(element)

# For speed reasons, we assume starting with `self.input_dir` is enough to be a real file.
# Other alternative would be too slow.
# TODO: Try using dictionary for higher accuracy.
indexed_paths = {
index: str(Path(element).resolve()) for index, element in enumerate(flattened_item) if is_path(element)
index: _to_path(element)
for index, element in enumerate(flattened_item)
if _is_path(self.input_dir.path, element)
}

if len(indexed_paths) == 0:
Expand All @@ -537,7 +546,6 @@ def is_path(element: Any) -> bool:
self.paths.append(paths)

items.append(tree_unflatten(flattened_item, spec))
self._collected_items += 1

self.items = items

Expand Down
23 changes: 23 additions & 0 deletions tests/processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
DataTransformRecipe,
_download_data_target,
_get_item_filesizes,
_is_path,
_map_items_to_workers_sequentially,
_map_items_to_workers_weighted,
_remove_target,
_to_path,
_upload_fn,
_wait_for_disk_usage_higher_than_threshold,
_wait_for_file_to_exist,
Expand Down Expand Up @@ -1136,3 +1138,24 @@ def test_load_torch_audio_from_wav_file(tmpdir, compression):
tensor = torchaudio.load(sample)
assert tensor[0].shape == torch.Size([1, 16000])
assert tensor[1] == 16000


def test_is_path_valid_in_studio(monkeypatch, tmpdir):
    """With `_IS_IN_STUDIO` patched to True, `_is_path` accepts a studio-prefixed string and an existing file."""
    studio_root = "/teamspace/studios/this_studio"

    existing_file = os.path.join(tmpdir, "a.png")
    with open(existing_file, "w") as handle:
        handle.write("Hello World")

    monkeypatch.setattr(data_processor_module, "_IS_IN_STUDIO", True)

    # A string that starts with the input dir.
    assert _is_path(studio_root, "/teamspace/studios/this_studio/a.png")
    # A file that was just created under tmpdir.
    assert _is_path(studio_root, existing_file)


@pytest.mark.skipif(sys.platform == "win32", reason="skip windows")
def test_to_path(tmpdir):
    """`_to_path` is expected to return both a `/teamspace` path and a tmpdir file path unchanged."""
    existing_file = os.path.join(tmpdir, "a.png")
    with open(existing_file, "w") as handle:
        handle.write("Hello World")

    studio_file = "/teamspace/studios/this_studio/a.png"

    assert _to_path(studio_file) == studio_file
    assert _to_path(existing_file) == existing_file

0 comments on commit f0b755a

Please sign in to comment.