diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 45a1bd6c..3abe1b41 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -131,6 +131,7 @@ def load_index_file(input_dir: str) -> Dict[str, Any]: Raises: ValueError: If the index file format is invalid. FileNotFoundError: If the index file does not exist in the input directory. + """ index_filepath = os.path.join(input_dir, _INDEX_FILENAME) try: @@ -139,10 +140,9 @@ def load_index_file(input_dir: str) -> Dict[str, Any]: if "chunks" in data: return data - elif "shards" in data: + if "shards" in data: return adapt_mds_shards_to_chunks(data) - else: - raise ValueError(f"Invalid index file format at {index_filepath}.") + raise ValueError(f"Invalid index file format at {index_filepath}.") except FileNotFoundError: raise FileNotFoundError(f"Index file not found at {index_filepath}.") diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index e37cc111..2384dd81 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -996,6 +996,7 @@ def test_subsample_streaming_dataset_with_token_loader(tmpdir, monkeypatch): assert len(dataset2) == int(len(dataset1) * 0.4) + @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows") def test_dataset_with_mosaic_mds_data(tmpdir): from PIL import Image