Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
nateprewitt committed Nov 27, 2023
1 parent 297604d commit 2118dc9
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 37 deletions.
48 changes: 27 additions & 21 deletions boto3/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
)

# Singletons for CRT-backed transfers
_CRT_S3_CLIENT = None
_BOTOCORE_CRT_SERIALIZER = None
CRT_S3_CLIENT = None
BOTOCORE_CRT_SERIALIZER = None

CLIENT_CREATION_LOCK = threading.Lock()
PROCESS_LOCK_NAME = 'boto3'
Expand All @@ -59,15 +59,10 @@ def _create_crt_request_serializer(session, region_name):
)


def _create_crt_s3_client(session, config, region_name, credentials, **kwargs):
def _create_crt_s3_client(
session, config, region_name, credentials, lock, **kwargs
):
"""Create boto3 wrapper class to manage crt lock reference and S3 client."""
lock = acquire_crt_s3_process_lock(PROCESS_LOCK_NAME)
if lock is None:
# If we're unable to acquire the lock, we cannot
# use the CRT in this process and should default to
# the classic s3transfer manager.
return None

cred_wrapper = BotocoreCRTCredentialsWrapper(credentials)
cred_provider = cred_wrapper.to_crt_credentials_provider()
return CRTS3Client(
Expand All @@ -79,30 +74,37 @@ def _create_crt_s3_client(session, config, region_name, credentials, **kwargs):


def _initialize_crt_transfer_primatives(client, config):
lock = acquire_crt_s3_process_lock(PROCESS_LOCK_NAME)
if lock is None:
# If we're unable to acquire the lock, we cannot
# use the CRT in this process and should default to
# the classic s3transfer manager.
return None, None

session = Session()
region_name = client.meta.region_name
credentials = client._get_credentials()

serializer = _create_crt_request_serializer(session, region_name)
s3_client = _create_crt_s3_client(
session, config, region_name, credentials
session, config, region_name, credentials, lock
)
return serializer, s3_client


def get_crt_s3_client(client, config):
global _CRT_S3_CLIENT
global _BOTOCORE_CRT_SERIALIZER
global CRT_S3_CLIENT
global BOTOCORE_CRT_SERIALIZER

with CLIENT_CREATION_LOCK:
if _CRT_S3_CLIENT is None:
if CRT_S3_CLIENT is None:
serializer, s3_client = _initialize_crt_transfer_primatives(
client, config
)
_BOTOCORE_CRT_SERIALIZER = serializer
_CRT_S3_CLIENT = s3_client
BOTOCORE_CRT_SERIALIZER = serializer
CRT_S3_CLIENT = s3_client

return _CRT_S3_CLIENT
return CRT_S3_CLIENT


class CRTS3Client:
Expand All @@ -125,15 +127,19 @@ def __init__(self, crt_client, process_lock, region, cred_provider):
def is_crt_compatible_request(client, crt_s3_client):
"""
Boto3 client must use same signing region and credentials
as the _CRT_S3_CLIENT singleton. Otherwise fallback to classic.
as the CRT_S3_CLIENT singleton. Otherwise fallback to classic.
"""
if crt_s3_client is None:
return False

is_same_region = client.meta.region_name == crt_s3_client.region
boto3_creds = client._get_credentials()
if boto3_creds is None:
return False

is_same_identity = compare_identity(
client._get_credentials(), crt_s3_client.cred_provider
boto3_creds.get_frozen_credentials(), crt_s3_client.cred_provider
)
is_same_region = client.meta.region_name == crt_s3_client.region
return is_same_region and is_same_identity


Expand All @@ -156,6 +162,6 @@ def create_crt_transfer_manager(client, config):
crt_s3_client = get_crt_s3_client(client, config)
if is_crt_compatible_request(client, crt_s3_client):
return CRTTransferManager(
crt_s3_client.crt_client, _BOTOCORE_CRT_SERIALIZER
crt_s3_client.crt_client, BOTOCORE_CRT_SERIALIZER
)
return None
91 changes: 75 additions & 16 deletions tests/unit/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,44 @@
from tests import mock, requires_crt

if HAS_CRT:
import awscrt.s3
from awscrt.s3 import CrossProcessLock as CrossProcessLockClass
from s3transfer.crt import BotocoreCRTCredentialsWrapper

import boto3.crt


@pytest.fixture
def mock_crt_process_lock(monkeypatch):
    """Yield a mock of ``awscrt.s3.CrossProcessLock`` (``None`` without CRT).

    The process lock is cached at the module layer whenever the cross
    process lock is successfully acquired. Resetting
    ``s3transfer.crt.CRT_S3_PROCESS_LOCK`` to ``None`` makes each test start
    with no previously cached lock, and patching the lock class means any
    cross process lock instantiated/acquired during the test is a mock that
    the test can use to control lock behavior.
    """
    if not HAS_CRT:
        # We cannot mock or use the lock without CRT support.
        yield None
        return

    monkeypatch.setattr('s3transfer.crt.CRT_S3_PROCESS_LOCK', None)
    lock_patcher = mock.patch('awscrt.s3.CrossProcessLock', spec=True)
    with lock_patcher as mock_lock:
        yield mock_lock


@pytest.fixture
def mock_crt_client_singleton(monkeypatch):
    """Reset the ``boto3.crt.CRT_S3_CLIENT`` singleton to ``None`` for a test.

    The attribute is only patched when CRT support is available; the
    fixture itself always yields ``None``.
    """
    if HAS_CRT:
        # Clear CRT state for each test
        monkeypatch.setattr('boto3.crt.CRT_S3_CLIENT', None)
    yield


@pytest.fixture
def mock_serializer_singleton(monkeypatch):
    """Reset ``boto3.crt.BOTOCORE_CRT_SERIALIZER`` to ``None`` for a test.

    The attribute is only patched when CRT support is available; the
    fixture itself always yields ``None``.
    """
    if HAS_CRT:
        # Clear CRT state for each test
        monkeypatch.setattr('boto3.crt.BOTOCORE_CRT_SERIALIZER', None)
    yield


def create_test_client(service_name='s3', region_name="us-east-1"):
return boto3.client(
service_name,
Expand All @@ -43,21 +75,35 @@ def create_test_client(service_name='s3', region_name="us-east-1"):

class TestCRTTransferManager:
@requires_crt()
def test_create_crt_transfer_manager_with_lock_in_use(self):
with mock.patch('boto3.crt.acquire_crt_s3_process_lock') as lock:
lock.return_value = None

# Verify we can't create a second CRT client
tm = boto3.crt.create_crt_transfer_manager(USW2_S3_CLIENT, None)
assert tm is None
def test_create_crt_transfer_manager_with_lock_in_use(
self,
mock_crt_process_lock,
mock_crt_client_singleton,
mock_serializer_singleton,
):
mock_crt_process_lock.return_value.acquire.side_effect = RuntimeError

# Verify we can't create a second CRT client
tm = boto3.crt.create_crt_transfer_manager(USW2_S3_CLIENT, None)
assert tm is None

@requires_crt()
def test_create_crt_transfer_manager(self):
def test_create_crt_transfer_manager(
self,
mock_crt_process_lock,
mock_crt_client_singleton,
mock_serializer_singleton,
):
tm = boto3.crt.create_crt_transfer_manager(USW2_S3_CLIENT, None)
assert isinstance(tm, s3transfer.crt.CRTTransferManager)

@requires_crt()
def test_crt_singleton_is_returned_every_call(self):
def test_crt_singleton_is_returned_every_call(
self,
mock_crt_process_lock,
mock_crt_client_singleton,
mock_serializer_singleton,
):
first_s3_client = boto3.crt.get_crt_s3_client(USW2_S3_CLIENT, None)
second_s3_client = boto3.crt.get_crt_s3_client(USW2_S3_CLIENT, None)

Expand All @@ -66,7 +112,12 @@ def test_crt_singleton_is_returned_every_call(self):
assert first_s3_client.crt_client is second_s3_client.crt_client

@requires_crt()
def test_create_crt_transfer_manager_w_client_in_wrong_region(self):
def test_create_crt_transfer_manager_w_client_in_wrong_region(
self,
mock_crt_process_lock,
mock_crt_client_singleton,
mock_serializer_singleton,
):
"""Ensure we don't return the crt transfer manager if client is in
different region. The CRT isn't able to handle region redirects and
will consistently fail.
Expand Down Expand Up @@ -130,20 +181,28 @@ def no_credentials():
)

@requires_crt()
def test_get_crt_s3_client(self):
def test_get_crt_s3_client(
self,
mock_crt_process_lock,
mock_crt_client_singleton,
mock_serializer_singleton,
):
config = TransferConfig()
crt_s3_client = boto3.crt.get_crt_s3_client(USW2_S3_CLIENT, config)
assert isinstance(crt_s3_client, boto3.crt.CRTS3Client)
assert isinstance(
crt_s3_client.process_lock, awscrt.s3.CrossProcessLock
)
assert isinstance(crt_s3_client.process_lock, CrossProcessLockClass)
assert crt_s3_client.region == "us-west-2"
assert isinstance(
crt_s3_client.cred_provider, BotocoreCRTCredentialsWrapper
)

@requires_crt()
def test_get_crt_s3_client_w_wrong_region(self):
def test_get_crt_s3_client_w_wrong_region(
self,
mock_crt_process_lock,
mock_crt_client_singleton,
mock_serializer_singleton,
):
config = TransferConfig()
crt_s3_client = boto3.crt.get_crt_s3_client(USW2_S3_CLIENT, config)
assert isinstance(crt_s3_client, boto3.crt.CRTS3Client)
Expand Down

0 comments on commit 2118dc9

Please sign in to comment.