Skip to content

Commit

Permalink
Merge branch 'main' into patch-2
Browse files (browse the repository at this point in the history)
  • Loading branch information
bhimrazy authored Sep 18, 2024
2 parents b3cc3a5 + c9aaf24 commit ecb0e17
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/litdata/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _check_requirement(self) -> None:
self.message = f"Requirement {self.requirement!r} met"
except Exception as ex:
self.available = False
self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`"
self.message = f"{ex.__class__.__name__}: {ex}.\n HINT: Try running `pip install -U {self.requirement!r}`"
requirement_contains_version_specifier = any(c in self.requirement for c in "=<>")
if not requirement_contains_version_specifier or self.module is not None:
module = self.requirement if self.module is None else self.module
Expand Down
4 changes: 2 additions & 2 deletions src/litdata/processing/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def map(
if _output_dir.url and "cloudspaces" in _output_dir.url:
raise ValueError(
f"The provided `output_dir` isn't valid. Found {_output_dir.path if _output_dir else None}."
" HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`."
"\n HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`."
)

if error_when_not_empty:
Expand Down Expand Up @@ -400,7 +400,7 @@ def optimize(
if _output_dir.url is not None and "cloudspaces" in _output_dir.url:
raise ValueError(
f"The provided `output_dir` isn't valid. Found {_output_dir.path}."
" HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`."
"\n HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`."
)

_assert_dir_has_index_file(_output_dir, mode=mode, use_checkpoint=use_checkpoint)
Expand Down
4 changes: 2 additions & 2 deletions src/litdata/processing/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _create_dataset(
project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)
cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None)
user_id = os.getenv("LIGHTNING_USER_ID", None)
cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None)
studio_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None)
lightning_app_id = os.getenv("LIGHTNING_CLOUD_APP_ID", None)

if project_id is None:
Expand All @@ -64,7 +64,7 @@ def _create_dataset(
try:
client.dataset_service_create_dataset(
body=ProjectIdDatasetsBody(
cloud_space_id=cloud_space_id if lightning_app_id is None else None,
cloud_space_id=studio_id if lightning_app_id is None else None,
cluster_id=cluster_id,
creator_id=user_id,
empty=empty,
Expand Down
4 changes: 2 additions & 2 deletions src/litdata/streaming/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __getitem__(self, index: int) -> Any:
if not _equal_items(data_1, data2):
raise ValueError(
f"Your dataset items aren't deterministic. Found {data_1} and {data2} for index {index}."
" HINT: Use the `litdata.cache.Cache` directly within your dataset."
"\n HINT: Use the `litdata.cache.Cache` directly within your dataset."
)
self._is_deterministic = True
self._cache[index] = data_1
Expand All @@ -115,7 +115,7 @@ class CacheCollateFn:
During the chunking phase, there is no need to return any data from the DataLoader reducing some time.
Additionally, if the user makes their __getitem__ asynchronous, the collate executes them in parallel.
Additionally, if the user makes their __getitem__ asynchronous, collate executes them in parallel.
"""

Expand Down
2 changes: 1 addition & 1 deletion src/litdata/streaming/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
if not cache.filled:
raise ValueError(
f"The provided dataset `{self.input_dir}` doesn't contain any {_INDEX_FILENAME} file."
" HINT: Did you successfully optimize a dataset to the provided `input_dir`?"
"\n HINT: Did you successfully optimize a dataset to the provided `input_dir`?"
)

return cache
Expand Down
34 changes: 17 additions & 17 deletions src/litdata/streaming/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Option
project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)

if cluster_id is None:
raise RuntimeError("The `cluster_id` couldn't be found from the environment variables.")
raise RuntimeError("The `LIGHTNING_CLUSTER_ID` couldn't be found from the environment variables.")

if project_id is None:
raise RuntimeError("The `project_id` couldn't be found from the environment variables.")
raise RuntimeError("The `LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables.")

clusters = client.cluster_service_list_project_clusters(project_id).clusters

Expand Down Expand Up @@ -147,7 +147,7 @@ def _resolve_s3_connections(dir_path: str) -> Dir:
# Get the ids from env variables
project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)
if project_id is None:
raise RuntimeError("The `project_id` couldn't be found from the environment variables.")
raise RuntimeError("The `LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables.")

target_name = dir_path.split("/")[3]

Expand All @@ -169,16 +169,16 @@ def _resolve_datasets(dir_path: str) -> Dir:
# Get the ids from env variables
cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None)
project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)
cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None)
studio_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None)

if cluster_id is None:
raise RuntimeError("The `cluster_id` couldn't be found from the environment variables.")
raise RuntimeError("The `LIGHTNING_CLUSTER_ID` couldn't be found from the environment variables.")

if project_id is None:
raise RuntimeError("The `project_id` couldn't be found from the environment variables.")
raise RuntimeError("The `LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables.")

if cloud_space_id is None:
raise RuntimeError("The `cloud_space_id` couldn't be found from the environment variables.")
if studio_id is None:
raise RuntimeError("The `LIGHTNING_CLOUD_SPACE_ID` couldn't be found from the environment variables.")

clusters = client.cluster_service_list_project_clusters(project_id).clusters

Expand All @@ -187,17 +187,17 @@ def _resolve_datasets(dir_path: str) -> Dir:
for cloudspace in client.cloud_space_service_list_cloud_spaces(
project_id=project_id, cluster_id=cluster_id
).cloudspaces
if cloudspace.id == cloud_space_id
if cloudspace.id == studio_id
]

if not target_cloud_space:
raise ValueError(f"We didn't find any matching Studio for the provided id `{cloud_space_id}`.")
raise ValueError(f"We didn't find any matching Studio for the provided id `{studio_id}`.")

target_cluster = [cluster for cluster in clusters if cluster.id == target_cloud_space[0].cluster_id]

if not target_cluster:
raise ValueError(
f"We didn't find a matching cluster associated with the id {target_cloud_space[0].cluster_id}."
f"We didn't find a matching cluster associated with the id `{target_cloud_space[0].cluster_id}`."
)

return Dir(
Expand All @@ -211,7 +211,7 @@ def _resolve_datasets(dir_path: str) -> Dir:

def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool = False) -> None:
if not isinstance(output_dir, Dir):
raise ValueError("The provided output_dir isn't a Dir Object.")
raise ValueError("The provided output_dir isn't a `Dir` Object.")

if output_dir.url is None:
return
Expand All @@ -234,7 +234,7 @@ def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool
if objects["KeyCount"] > 0:
raise RuntimeError(
f"The provided output_dir `{output_dir.path}` already contains data and datasets are meant to be immutable."
" HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?"
"\n HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?"
)


Expand All @@ -261,8 +261,8 @@ def _assert_dir_has_index_file(
if os.path.exists(index_file) and mode is None:
raise RuntimeError(
f"The provided output_dir `{output_dir.path}` already contains an optimized immutable datasets."
" HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?"
" HINT: If you want to append/overwrite to the existing dataset, use `mode='append | overwrite'`."
"\n HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?"
"\n HINT: If you want to append/overwrite to the existing dataset, use `mode='append | overwrite'`."
)

# delete index.json file and chunks
Expand Down Expand Up @@ -310,8 +310,8 @@ def _assert_dir_has_index_file(
if has_index_file and mode is None:
raise RuntimeError(
f"The provided output_dir `{output_dir.path}` already contains an optimized immutable datasets."
" HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?"
" HINT: If you want to append/overwrite to the existing dataset, use `mode='append | overwrite'`."
"\n HINT: Did you consider changing the `output_dir` with your own versioning as a suffix?"
"\n HINT: If you want to append/overwrite to the existing dataset, use `mode='append | overwrite'`."
)

# Delete all the files (including the index file in overwrite mode)
Expand Down
3 changes: 2 additions & 1 deletion src/litdata/streaming/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]:
if isinstance(item, JpegImageFile):
if not hasattr(item, "filename"):
raise ValueError(
"The JPEG Image's filename isn't defined. HINT: Open the image in your Dataset __getitem__ method."
"The JPEG Image's filename isn't defined."
"\n HINT: Open the image in your Dataset `__getitem__` method."
)
if item.filename and os.path.isfile(item.filename):
# read the content of the file directly
Expand Down
2 changes: 1 addition & 1 deletion src/litdata/utilities/dataset_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def subsample_streaming_dataset(
else:
raise ValueError(
f"The provided dataset `{input_dir.path}` doesn't contain any {_INDEX_FILENAME} file."
" HINT: Did you successfully optimize a dataset to the provided `input_dir`?"
"\n HINT: Did you successfully optimize a dataset to the provided `input_dir`?"
)

assert len(original_chunks) > 0, f"No chunks found in the `{input_dir}/index.json` file"
Expand Down
14 changes: 8 additions & 6 deletions tests/streaming/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def test_src_resolver_s3_connections(monkeypatch, lightning_cloud_mock):
auth = login.Auth()
auth.save(user_id="7c8455e3-7c5f-4697-8a6d-105971d6b9bd", api_key="e63fae57-2b50-498b-bc46-d6204cbf330e")

with pytest.raises(RuntimeError, match="`project_id` couldn't be found from the environment variables."):
with pytest.raises(
RuntimeError, match="`LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables."
):
resolver._resolve_dir("/teamspace/s3_connections/imagenet")

monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "project_id")
Expand Down Expand Up @@ -60,12 +62,12 @@ def test_src_resolver_studios(monkeypatch, lightning_cloud_mock):
auth = login.Auth()
auth.save(user_id="7c8455e3-7c5f-4697-8a6d-105971d6b9bd", api_key="e63fae57-2b50-498b-bc46-d6204cbf330e")

with pytest.raises(RuntimeError, match="`cluster_id`"):
with pytest.raises(RuntimeError, match="`LIGHTNING_CLUSTER_ID`"):
resolver._resolve_dir("/teamspace/studios/other_studio")

monkeypatch.setenv("LIGHTNING_CLUSTER_ID", "cluster_id")

with pytest.raises(RuntimeError, match="`project_id`"):
with pytest.raises(RuntimeError, match="`LIGHTNING_CLOUD_PROJECT_ID`"):
resolver._resolve_dir("/teamspace/studios/other_studio")

monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "project_id")
Expand Down Expand Up @@ -138,17 +140,17 @@ def test_src_resolver_datasets(monkeypatch, lightning_cloud_mock):

assert resolver._resolve_dir("s3://bucket_name").url == "s3://bucket_name"

with pytest.raises(RuntimeError, match="`cluster_id`"):
with pytest.raises(RuntimeError, match="`LIGHTNING_CLUSTER_ID`"):
resolver._resolve_dir("/teamspace/datasets/imagenet")

monkeypatch.setenv("LIGHTNING_CLUSTER_ID", "cluster_id")

with pytest.raises(RuntimeError, match="`project_id`"):
with pytest.raises(RuntimeError, match="`LIGHTNING_CLOUD_PROJECT_ID`"):
resolver._resolve_dir("/teamspace/datasets/imagenet")

monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "project_id")

with pytest.raises(RuntimeError, match="`cloud_space_id`"):
with pytest.raises(RuntimeError, match="`LIGHTNING_CLOUD_SPACE_ID`"):
resolver._resolve_dir("/teamspace/datasets/imagenet")

monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "cloud_space_id")
Expand Down

0 comments on commit ecb0e17

Please sign in to comment.