diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py
index c2ed1b26..26d38ad0 100644
--- a/src/litdata/processing/data_processor.py
+++ b/src/litdata/processing/data_processor.py
@@ -1019,7 +1019,7 @@ def run(self, data_recipe: DataRecipe) -> None:
         print("Workers are finished.")
         result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)
 
-        if num_nodes == node_rank + 1 and self.output_dir.url and _IS_IN_STUDIO and self.input_dir.path:
+        if num_nodes == node_rank + 1 and self.output_dir.url and _IS_IN_STUDIO:
             assert self.output_dir.path
             _create_dataset(
                 input_dir=self.input_dir.path,
diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py
index 0e4e5372..f1b18b1a 100644
--- a/tests/processing/test_data_processor.py
+++ b/tests/processing/test_data_processor.py
@@ -5,6 +5,7 @@
 from pathlib import Path
 from typing import Any, List
 from unittest import mock
+from unittest.mock import ANY, Mock
 
 import numpy as np
 import pytest
@@ -901,7 +902,7 @@ def map_fn_index(index, output_dir):
 
 
 @pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "win32", reason="Requires: ['pil']")
-def test_data_processing_map_without_input_dir(monkeypatch, tmpdir):
+def test_data_processing_map_without_input_dir_local(monkeypatch, tmpdir):
     cache_dir = os.path.join(tmpdir, "cache")
     output_dir = os.path.join(tmpdir, "target_dir")
     os.makedirs(output_dir, exist_ok=True)
@@ -920,6 +921,46 @@ def test_data_processing_map_without_input_dir(monkeypatch, tmpdir):
     assert sorted(os.listdir(output_dir)) == ["0.JPEG", "1.JPEG", "2.JPEG", "3.JPEG", "4.JPEG"]
 
 
+@pytest.mark.skipif(sys.platform == "win32", reason="Windows not supported")
+def test_data_processing_map_without_input_dir_remote(monkeypatch, tmpdir):
+    cache_dir = os.path.join(tmpdir, "cache")
+    output_dir = os.path.join("/teamspace", "datasets", "target_dir")
+
+    monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
+    monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", cache_dir)
+
+    create_dataset_mock = Mock()
+    monkeypatch.setenv("LIGHTNING_CLUSTER_ID", "1")
+    monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "2")
+    monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "3")
+    monkeypatch.setattr("litdata.processing.data_processor._IS_IN_STUDIO", True)
+    monkeypatch.setattr(
+        "litdata.streaming.resolver._resolve_datasets",
+        Mock(return_value=Dir(path=tmpdir / "output", url="url")),
+    )
+    monkeypatch.setattr("litdata.processing.data_processor._create_dataset", create_dataset_mock)
+
+    map(
+        map_fn_index,
+        list(range(5)),
+        output_dir=output_dir,
+        num_workers=1,
+    )
+
+    create_dataset_mock.assert_called_with(
+        input_dir=None,
+        storage_dir=str(tmpdir / "output"),
+        dataset_type=ANY,
+        empty=ANY,
+        size=ANY,
+        num_bytes=ANY,
+        data_format=ANY,
+        compression=ANY,
+        num_chunks=ANY,
+        num_bytes_per_chunk=ANY,
+    )
+
+
 @pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "win32", reason="Requires: ['pil']")
 def test_data_processing_map_weights_mismatch(monkeypatch, tmpdir):
     cache_dir = os.path.join(tmpdir, "cache")