diff --git a/boto3/crt.py b/boto3/crt.py index bf032e0281..4b8df3140e 100644 --- a/boto3/crt.py +++ b/boto3/crt.py @@ -32,8 +32,8 @@ ) # Singletons for CRT-backed transfers -_CRT_S3_CLIENT = None -_BOTOCORE_CRT_SERIALIZER = None +CRT_S3_CLIENT = None +BOTOCORE_CRT_SERIALIZER = None CLIENT_CREATION_LOCK = threading.Lock() PROCESS_LOCK_NAME = 'boto3' @@ -59,15 +59,10 @@ def _create_crt_request_serializer(session, region_name): ) -def _create_crt_s3_client(session, config, region_name, credentials, **kwargs): +def _create_crt_s3_client( + session, config, region_name, credentials, lock, **kwargs +): """Create boto3 wrapper class to manage crt lock reference and S3 client.""" - lock = acquire_crt_s3_process_lock(PROCESS_LOCK_NAME) - if lock is None: - # If we're unable to acquire the lock, we cannot - # use the CRT in this process and should default to - # the classic s3transfer manager. - return None - cred_wrapper = BotocoreCRTCredentialsWrapper(credentials) cred_provider = cred_wrapper.to_crt_credentials_provider() return CRTS3Client( @@ -79,30 +74,37 @@ def _create_crt_s3_client(session, config, region_name, credentials, **kwargs): def _initialize_crt_transfer_primatives(client, config): + lock = acquire_crt_s3_process_lock(PROCESS_LOCK_NAME) + if lock is None: + # If we're unable to acquire the lock, we cannot + # use the CRT in this process and should default to + # the classic s3transfer manager. + return None, None + session = Session() region_name = client.meta.region_name credentials = client._get_credentials() serializer = _create_crt_request_serializer(session, region_name) s3_client = _create_crt_s3_client( - session, config, region_name, credentials + session, config, region_name, credentials, lock ) return serializer, s3_client def get_crt_s3_client(client, config): - global _CRT_S3_CLIENT - global _BOTOCORE_CRT_SERIALIZER + global CRT_S3_CLIENT + global BOTOCORE_CRT_SERIALIZER with CLIENT_CREATION_LOCK: - if _CRT_S3_CLIENT is None: + if CRT_S3_CLIENT is None: serializer, s3_client = _initialize_crt_transfer_primatives( client, config ) - _BOTOCORE_CRT_SERIALIZER = serializer - _CRT_S3_CLIENT = s3_client + BOTOCORE_CRT_SERIALIZER = serializer + CRT_S3_CLIENT = s3_client - return _CRT_S3_CLIENT + return CRT_S3_CLIENT class CRTS3Client: @@ -125,15 +127,19 @@ def __init__(self, crt_client, process_lock, region, cred_provider): def is_crt_compatible_request(client, crt_s3_client): """ Boto3 client must use same signing region and credentials - as the _CRT_S3_CLIENT singleton. Otherwise fallback to classic. + as the CRT_S3_CLIENT singleton. Otherwise fallback to classic. """ if crt_s3_client is None: return False - is_same_region = client.meta.region_name == crt_s3_client.region + boto3_creds = client._get_credentials() + if boto3_creds is None: + return False + is_same_identity = compare_identity( - client._get_credentials(), crt_s3_client.cred_provider + boto3_creds.get_frozen_credentials(), crt_s3_client.cred_provider ) + is_same_region = client.meta.region_name == crt_s3_client.region return is_same_region and is_same_identity @@ -156,6 +162,6 @@ def create_crt_transfer_manager(client, config): crt_s3_client = get_crt_s3_client(client, config) if is_crt_compatible_request(client, crt_s3_client): return CRTTransferManager( - crt_s3_client.crt_client, _BOTOCORE_CRT_SERIALIZER + crt_s3_client.crt_client, BOTOCORE_CRT_SERIALIZER ) return None diff --git a/tests/unit/test_crt.py b/tests/unit/test_crt.py index e7bc6b47c0..c88eb36fc2 100644 --- a/tests/unit/test_crt.py +++ b/tests/unit/test_crt.py @@ -21,12 +21,44 @@ from tests import mock, requires_crt if HAS_CRT: - import awscrt.s3 + from awscrt.s3 import CrossProcessLock as CrossProcessLockClass from s3transfer.crt import BotocoreCRTCredentialsWrapper import boto3.crt +@pytest.fixture +def mock_crt_process_lock(monkeypatch): + # The process lock is cached at the module layer whenever the + # cross process lock is successfully acquired. This patch ensures that + # test cases will start off with no previously cached process lock and + # if a cross process is instantiated/acquired it will be the mock that + # can be used for controlling lock behavior. + if HAS_CRT: + monkeypatch.setattr('s3transfer.crt.CRT_S3_PROCESS_LOCK', None) + with mock.patch('awscrt.s3.CrossProcessLock', spec=True) as mock_lock: + yield mock_lock + else: + # We cannot mock or use the lock without CRT support. + yield None + + +@pytest.fixture +def mock_crt_client_singleton(monkeypatch): + # Clear CRT state for each test + if HAS_CRT: + monkeypatch.setattr('boto3.crt.CRT_S3_CLIENT', None) + yield None + + +@pytest.fixture +def mock_serializer_singleton(monkeypatch): + # Clear CRT state for each test + if HAS_CRT: + monkeypatch.setattr('boto3.crt.BOTOCORE_CRT_SERIALIZER', None) + yield None + + def create_test_client(service_name='s3', region_name="us-east-1"): return boto3.client( service_name, @@ -43,21 +75,35 @@ def create_test_client(service_name='s3', region_name="us-east-1"): class TestCRTTransferManager: @requires_crt() - def test_create_crt_transfer_manager_with_lock_in_use(self): - with mock.patch('boto3.crt.acquire_crt_s3_process_lock') as lock: - lock.return_value = None - - # Verify we can't create a second CRT client - tm = boto3.crt.create_crt_transfer_manager(USW2_S3_CLIENT, None) - assert tm is None + def test_create_crt_transfer_manager_with_lock_in_use( + self, + mock_crt_process_lock, + mock_crt_client_singleton, + mock_serializer_singleton, + ): + mock_crt_process_lock.return_value.acquire.side_effect = RuntimeError + + # Verify we can't create a second CRT client + tm = boto3.crt.create_crt_transfer_manager(USW2_S3_CLIENT, None) + assert tm is None @requires_crt() - def test_create_crt_transfer_manager(self): + def test_create_crt_transfer_manager( + self, + mock_crt_process_lock, + mock_crt_client_singleton, + mock_serializer_singleton, + ): tm = boto3.crt.create_crt_transfer_manager(USW2_S3_CLIENT, None) assert isinstance(tm, s3transfer.crt.CRTTransferManager) @requires_crt() - def test_crt_singleton_is_returned_every_call(self): + def test_crt_singleton_is_returned_every_call( + self, + mock_crt_process_lock, + mock_crt_client_singleton, + mock_serializer_singleton, + ): first_s3_client = boto3.crt.get_crt_s3_client(USW2_S3_CLIENT, None) second_s3_client = boto3.crt.get_crt_s3_client(USW2_S3_CLIENT, None) @@ -66,7 +112,12 @@ def test_crt_singleton_is_returned_every_call(self): assert first_s3_client.crt_client is second_s3_client.crt_client @requires_crt() - def test_create_crt_transfer_manager_w_client_in_wrong_region(self): + def test_create_crt_transfer_manager_w_client_in_wrong_region( + self, + mock_crt_process_lock, + mock_crt_client_singleton, + mock_serializer_singleton, + ): """Ensure we don't return the crt transfer manager if client is in different region. The CRT isn't able to handle region redirects and will consistently fail. @@ -130,20 +181,28 @@ def no_credentials(): ) @requires_crt() - def test_get_crt_s3_client(self): + def test_get_crt_s3_client( + self, + mock_crt_process_lock, + mock_crt_client_singleton, + mock_serializer_singleton, + ): config = TransferConfig() crt_s3_client = boto3.crt.get_crt_s3_client(USW2_S3_CLIENT, config) assert isinstance(crt_s3_client, boto3.crt.CRTS3Client) - assert isinstance( - crt_s3_client.process_lock, awscrt.s3.CrossProcessLock - ) + assert isinstance(crt_s3_client.process_lock, CrossProcessLockClass) assert crt_s3_client.region == "us-west-2" assert isinstance( crt_s3_client.cred_provider, BotocoreCRTCredentialsWrapper ) @requires_crt() - def test_get_crt_s3_client_w_wrong_region(self): + def test_get_crt_s3_client_w_wrong_region( + self, + mock_crt_process_lock, + mock_crt_client_singleton, + mock_serializer_singleton, + ): config = TransferConfig() crt_s3_client = boto3.crt.get_crt_s3_client(USW2_S3_CLIENT, config) assert isinstance(crt_s3_client, boto3.crt.CRTS3Client)