Skip to content

Commit

Permalink
Resolve s3 credentials wrongly defined (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton authored Feb 26, 2024
1 parent 6683b08 commit 00930ee
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 11 deletions.
12 changes: 5 additions & 7 deletions litdata/streaming/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from time import time
from typing import Any, Optional

from litdata.constants import _BOTO3_AVAILABLE
from litdata.constants import _BOTO3_AVAILABLE, _IS_IN_STUDIO

if _BOTO3_AVAILABLE:
import boto3
Expand All @@ -17,15 +17,14 @@ class S3Client:
def __init__(self, refetch_interval: int = 3300) -> None:
self._refetch_interval = refetch_interval
self._last_time: Optional[float] = None
self._has_cloud_space_id: bool = "LIGHTNING_CLOUD_SPACE_ID" in os.environ
self._client: Optional[Any] = None

def _create_client(self) -> None:
has_shared_credentials_file = (
os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials"
)

if has_shared_credentials_file:
if has_shared_credentials_file or not _IS_IN_STUDIO:
self._client = boto3.client(
"s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
)
Expand All @@ -42,10 +41,9 @@ def _create_client(self) -> None:

@property
def client(self) -> Any:
if not self._has_cloud_space_id:
if self._client is None:
self._create_client()
return self._client
if self._client is None:
self._create_client()
self._last_time = time()

# Re-generate credentials for EC2
if self._last_time is None or (time() - self._last_time) > self._refetch_interval:
Expand Down
7 changes: 3 additions & 4 deletions tests/streaming/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,16 @@ def test_s3_client_without_cloud_space_id(monkeypatch):


@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows")
@pytest.mark.parametrize("use_shared_credentials", [False, True])
@pytest.mark.parametrize("use_shared_credentials", [False, True, None])
def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch):
boto3 = mock.MagicMock()
monkeypatch.setattr(client, "boto3", boto3)

botocore = mock.MagicMock()
monkeypatch.setattr(client, "botocore", botocore)

monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy")

if use_shared_credentials:
if isinstance(use_shared_credentials, bool):
monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy")
monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", "/.credentials/.aws_credentials")
monkeypatch.setenv("AWS_CONFIG_FILE", "/.credentials/.aws_credentials")

Expand Down

0 comments on commit 00930ee

Please sign in to comment.