diff --git a/rohmu/object_storage/s3.py b/rohmu/object_storage/s3.py index e0bdb3d9..f82f7c13 100644 --- a/rohmu/object_storage/s3.py +++ b/rohmu/object_storage/s3.py @@ -32,12 +32,14 @@ ) from rohmu.typing import Metadata from rohmu.util import batched, ProgressStream +from threading import RLock from typing import Any, BinaryIO, cast, Collection, Iterator, Optional, Tuple, TYPE_CHECKING, Union import botocore.client import botocore.config import botocore.exceptions import botocore.session +import contextlib import math import time @@ -124,7 +126,6 @@ def __init__( statsd_info: Optional[StatsdConfig] = None, ) -> None: super().__init__(prefix=prefix, notifier=notifier, statsd_info=statsd_info) - session = botocore.session.get_session() self.bucket_name = bucket_name self.location = "" self.region = region @@ -140,14 +141,15 @@ def __init__( custom_config["proxies"] = {"https": proxy_url} if use_dualstack_endpoint is True: custom_config["use_dualstack_endpoint"] = True - self.s3_client = create_s3_client( - session=session, - config=botocore.config.Config(**custom_config), - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_session_token=aws_session_token, - region_name=region, - ) + with self._get_session() as session: + self.s3_client = create_s3_client( + session=session, + config=botocore.config.Config(**custom_config), + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + region_name=region, + ) if self.region and self.region != "us-east-1": self.location = self.region else: @@ -173,16 +175,18 @@ def __init__( ) if not is_verify_tls and cert_path is not None: raise ValueError("cert_path is set but is_verify_tls is False") - self.s3_client = create_s3_client( - session=session, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_session_token=aws_session_token, - config=boto_config, - endpoint_url=custom_url, - region_name=region, - verify=str(cert_path) if cert_path is not None and is_verify_tls else is_verify_tls, - ) + + with self._get_session() as session: + self.s3_client = create_s3_client( + session=session, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + config=boto_config, + endpoint_url=custom_url, + region_name=region, + verify=str(cert_path) if cert_path is not None and is_verify_tls else is_verify_tls, + ) self.check_or_create_bucket() @@ -190,6 +194,24 @@ def __init__( self.encrypted = encrypted self.log.debug("S3Transfer initialized") + # It is advantageous to share the Session as much as possible since the very + # large service model files (eg botocore/data/ec2/2016-11-15/service-2.json) + # are cached on the Session, otherwise they will need to be loaded for every + # Client - which takes a lot of time and memory. + # Sessions are not threadsafe. We use a lock to ensure that only one thread + # is creating a client at a time. Clients are threadsafe, so it is okay for + # the Client to "escape" the lock with any state it shares with the Session. + _botocore_session_lock = RLock() + _botocore_session: botocore.session.Session | None = None + + @classmethod + @contextlib.contextmanager + def _get_session(cls) -> Iterator[botocore.session.Session]: + with cls._botocore_session_lock: + if cls._botocore_session is None: + cls._botocore_session = botocore.session.get_session() + yield cls._botocore_session + def copy_file( self, *, source_key: str, destination_key: str, metadata: Optional[Metadata] = None, **_kwargs: Any ) -> None: diff --git a/test/object_storage/test_s3.py b/test/object_storage/test_s3.py index da21afcc..be31ecb0 100644 --- a/test/object_storage/test_s3.py +++ b/test/object_storage/test_s3.py @@ -16,6 +16,7 @@ from typing import Any, BinaryIO, Callable, Iterator, Optional, Union from unittest.mock import ANY, call, MagicMock, patch +import contextlib import pytest import rohmu.object_storage.s3 @@ -31,10 +32,16 @@ class S3Infra: @pytest.fixture(name="infra") def fixture_infra(mocker: Any) -> Iterator[S3Infra]: notifier = MagicMock() - get_session = mocker.patch("botocore.session.get_session") s3_client = MagicMock() create_client = MagicMock(return_value=s3_client) - get_session.return_value = MagicMock(create_client=create_client) + session = MagicMock(create_client=create_client) + + @contextlib.contextmanager + def _get_session(cls: S3Transfer) -> Iterator[MagicMock]: + yield session + + mocker.patch("rohmu.object_storage.s3.S3Transfer._get_session", _get_session) + operation = mocker.patch("rohmu.common.statsd.StatsClient.operation") transfer = S3Transfer( region="test-region",