From 09cfea1edd2858bf153d190c3eab53ae5e59178e Mon Sep 17 00:00:00 2001
From: MohanReddy
Date: Fri, 6 Sep 2024 12:23:12 +0530
Subject: [PATCH 1/2] azure storage options

---
 src/litdata/utilities/dataset_utilities.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py
index fe8b132e..b2119866 100644
--- a/src/litdata/utilities/dataset_utilities.py
+++ b/src/litdata/utilities/dataset_utilities.py
@@ -38,7 +38,7 @@ def subsample_streaming_dataset(
 
     # Make sure input_dir contains cache path and remote url
     if _should_replace_path(input_dir.path):
-        cache_path = _try_create_cache_dir(input_dir=input_dir.path if input_dir.path else input_dir.url)
+        cache_path = _try_create_cache_dir(input_dir=input_dir.path if input_dir.path else input_dir.url, storage_options=storage_options)
         if cache_path is not None:
             input_dir.path = cache_path
 
@@ -96,7 +96,7 @@ def _should_replace_path(path: Optional[str]) -> bool:
     return path.startswith("/teamspace/datasets/") or path.startswith("/teamspace/s3_connections/")
 
 
-def _read_updated_at(input_dir: Optional[Dir]) -> str:
+def _read_updated_at(input_dir: Optional[Dir],storage_options: Optional[Dict] = {}) -> str:
     """Read last updated timestamp from index.json file."""
     last_updation_timestamp = "0"
     index_json_content = None
@@ -110,7 +110,7 @@ def _read_updated_at(input_dir: Optional[Dir]) -> str:
         # download index.json file and read last_updation_timestamp
         with tempfile.TemporaryDirectory() as tmp_directory:
             temp_index_filepath = os.path.join(tmp_directory, _INDEX_FILENAME)
-            downloader = get_downloader_cls(input_dir.url, tmp_directory, [])
+            downloader = get_downloader_cls(input_dir.url, tmp_directory, [] ,storage_options)
             downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), temp_index_filepath)
 
             index_json_content = load_index_file(tmp_directory)
@@ -135,9 +135,9 @@ def _clear_cache_dir_if_updated(input_dir_hash_filepath: str, updated_at_hash: s
         shutil.rmtree(input_dir_hash_filepath)
 
 
-def _try_create_cache_dir(input_dir: Optional[str]) -> Optional[str]:
+def _try_create_cache_dir(input_dir: Optional[str] , storage_options: Optional[Dict] = {}) -> Optional[str]:
     resolved_input_dir = _resolve_dir(input_dir)
-    updated_at = _read_updated_at(resolved_input_dir)
+    updated_at = _read_updated_at(resolved_input_dir ,storage_options)
 
     if updated_at == "0" and input_dir is not None:
         updated_at = hashlib.md5(input_dir.encode()).hexdigest()  # noqa: S324

From a5f0ae627e1c907014635372fd832f0f2c6fcffc Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 6 Sep 2024 06:55:40 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/litdata/utilities/dataset_utilities.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py
index b2119866..ca43bedd 100644
--- a/src/litdata/utilities/dataset_utilities.py
+++ b/src/litdata/utilities/dataset_utilities.py
@@ -38,7 +38,9 @@ def subsample_streaming_dataset(
 
     # Make sure input_dir contains cache path and remote url
     if _should_replace_path(input_dir.path):
-        cache_path = _try_create_cache_dir(input_dir=input_dir.path if input_dir.path else input_dir.url, storage_options=storage_options)
+        cache_path = _try_create_cache_dir(
+            input_dir=input_dir.path if input_dir.path else input_dir.url, storage_options=storage_options
+        )
         if cache_path is not None:
             input_dir.path = cache_path
 
@@ -96,7 +98,7 @@ def _should_replace_path(path: Optional[str]) -> bool:
     return path.startswith("/teamspace/datasets/") or path.startswith("/teamspace/s3_connections/")
 
 
-def _read_updated_at(input_dir: Optional[Dir],storage_options: Optional[Dict] = {}) -> str:
+def _read_updated_at(input_dir: Optional[Dir], storage_options: Optional[Dict] = {}) -> str:
     """Read last updated timestamp from index.json file."""
     last_updation_timestamp = "0"
     index_json_content = None
@@ -110,7 +112,7 @@ def _read_updated_at(input_dir: Optional[Dir], storage_options: Optional[Dict] =
         # download index.json file and read last_updation_timestamp
         with tempfile.TemporaryDirectory() as tmp_directory:
             temp_index_filepath = os.path.join(tmp_directory, _INDEX_FILENAME)
-            downloader = get_downloader_cls(input_dir.url, tmp_directory, [] ,storage_options)
+            downloader = get_downloader_cls(input_dir.url, tmp_directory, [], storage_options)
             downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), temp_index_filepath)
 
             index_json_content = load_index_file(tmp_directory)
@@ -135,9 +137,9 @@ def _clear_cache_dir_if_updated(input_dir_hash_filepath: str, updated_at_hash: s
         shutil.rmtree(input_dir_hash_filepath)
 
 
-def _try_create_cache_dir(input_dir: Optional[str] , storage_options: Optional[Dict] = {}) -> Optional[str]:
+def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Dict] = {}) -> Optional[str]:
     resolved_input_dir = _resolve_dir(input_dir)
-    updated_at = _read_updated_at(resolved_input_dir, storage_options)
+    updated_at = _read_updated_at(resolved_input_dir, storage_options)
 
     if updated_at == "0" and input_dir is not None:
         updated_at = hashlib.md5(input_dir.encode()).hexdigest()  # noqa: S324
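
Usage note: the patch threads `storage_options` from the public entry point down
through `_try_create_cache_dir` -> `_read_updated_at` -> `get_downloader_cls`, so
backend credentials can reach the downloader when the cache timestamp is checked.
A minimal sketch of how a caller might supply them, assuming `StreamingDataset`
accepts and forwards `storage_options`; the `azure://` URI scheme and the
`account_url`/`credential` keys are illustrative Azure SDK kwargs, not values
guaranteed by this patch:

    from litdata import StreamingDataset

    # Backend-specific options; the dict is passed through unmodified to the
    # downloader selected by get_downloader_cls(). Keys below are hypothetical
    # examples for an Azure blob backend.
    storage_options = {
        "account_url": "https://<account>.blob.core.windows.net",
        "credential": "<sas-token-or-account-key>",
    }

    # The downloader is chosen from the URI scheme of input_dir; `azure://`
    # here is illustrative of a remote dataset location.
    dataset = StreamingDataset(
        input_dir="azure://my-container/my-dataset",
        storage_options=storage_options,
    )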