From 00930ee9c2a9567e67d95dc404008c75127f653e Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 26 Feb 2024 11:34:49 +0000 Subject: [PATCH] Resolve s3 credentials wrongly defined (#27) --- litdata/streaming/client.py | 12 +++++------- tests/streaming/test_client.py | 7 +++---- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/litdata/streaming/client.py b/litdata/streaming/client.py index 0a12616d..28a0566f 100644 --- a/litdata/streaming/client.py +++ b/litdata/streaming/client.py @@ -2,7 +2,7 @@ from time import time from typing import Any, Optional -from litdata.constants import _BOTO3_AVAILABLE +from litdata.constants import _BOTO3_AVAILABLE, _IS_IN_STUDIO if _BOTO3_AVAILABLE: import boto3 @@ -17,7 +17,6 @@ class S3Client: def __init__(self, refetch_interval: int = 3300) -> None: self._refetch_interval = refetch_interval self._last_time: Optional[float] = None - self._has_cloud_space_id: bool = "LIGHTNING_CLOUD_SPACE_ID" in os.environ self._client: Optional[Any] = None def _create_client(self) -> None: @@ -25,7 +24,7 @@ def _create_client(self) -> None: os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials" ) - if has_shared_credentials_file: + if has_shared_credentials_file or not _IS_IN_STUDIO: self._client = boto3.client( "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}) ) @@ -42,10 +41,9 @@ def _create_client(self) -> None: @property def client(self) -> Any: - if not self._has_cloud_space_id: - if self._client is None: - self._create_client() - return self._client + if self._client is None: + self._create_client() + self._last_time = time() # Re-generate credentials for EC2 if self._last_time is None or (time() - self._last_time) > self._refetch_interval: diff --git a/tests/streaming/test_client.py b/tests/streaming/test_client.py index ca933604..64163a64 100644 --- a/tests/streaming/test_client.py +++ b/tests/streaming/test_client.py @@ -31,7 +31,7 @@ def test_s3_client_without_cloud_space_id(monkeypatch): @pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows") -@pytest.mark.parametrize("use_shared_credentials", [False, True]) +@pytest.mark.parametrize("use_shared_credentials", [False, True, None]) def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch): boto3 = mock.MagicMock() monkeypatch.setattr(client, "boto3", boto3) @@ -39,9 +39,8 @@ def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch): botocore = mock.MagicMock() monkeypatch.setattr(client, "botocore", botocore) - monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy") - - if use_shared_credentials: + if isinstance(use_shared_credentials, bool): + monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy") monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", "/.credentials/.aws_credentials") monkeypatch.setenv("AWS_CONFIG_FILE", "/.credentials/.aws_credentials")