diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index fe8b132e..ca43bedd 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -38,7 +38,9 @@ def subsample_streaming_dataset( # Make sure input_dir contains cache path and remote url if _should_replace_path(input_dir.path): - cache_path = _try_create_cache_dir(input_dir=input_dir.path if input_dir.path else input_dir.url) + cache_path = _try_create_cache_dir( + input_dir=input_dir.path if input_dir.path else input_dir.url, storage_options=storage_options + ) if cache_path is not None: input_dir.path = cache_path @@ -96,7 +98,7 @@ def _should_replace_path(path: Optional[str]) -> bool: return path.startswith("/teamspace/datasets/") or path.startswith("/teamspace/s3_connections/") -def _read_updated_at(input_dir: Optional[Dir]) -> str: +def _read_updated_at(input_dir: Optional[Dir], storage_options: Optional[Dict] = {}) -> str: """Read last updated timestamp from index.json file.""" last_updation_timestamp = "0" index_json_content = None @@ -110,7 +112,7 @@ def _read_updated_at(input_dir: Optional[Dir]) -> str: # download index.json file and read last_updation_timestamp with tempfile.TemporaryDirectory() as tmp_directory: temp_index_filepath = os.path.join(tmp_directory, _INDEX_FILENAME) - downloader = get_downloader_cls(input_dir.url, tmp_directory, []) + downloader = get_downloader_cls(input_dir.url, tmp_directory, [], storage_options) downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), temp_index_filepath) index_json_content = load_index_file(tmp_directory) @@ -135,9 +137,9 @@ def _clear_cache_dir_if_updated(input_dir_hash_filepath: str, updated_at_hash: s shutil.rmtree(input_dir_hash_filepath) -def _try_create_cache_dir(input_dir: Optional[str]) -> Optional[str]: +def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Dict] = {}) -> Optional[str]: resolved_input_dir = _resolve_dir(input_dir) - updated_at = _read_updated_at(resolved_input_dir) + updated_at = _read_updated_at(resolved_input_dir, storage_options) if updated_at == "0" and input_dir is not None: updated_at = hashlib.md5(input_dir.encode()).hexdigest() # noqa: S324