diff --git a/src/rpdk/core/boto_helpers.py b/src/rpdk/core/boto_helpers.py index e29d795a..9a9337c1 100644 --- a/src/rpdk/core/boto_helpers.py +++ b/src/rpdk/core/boto_helpers.py @@ -1,6 +1,8 @@ import logging from datetime import datetime +import botocore.loaders +import botocore.regions from boto3 import Session as Boto3Session from botocore.exceptions import ClientError @@ -31,7 +33,11 @@ def _known_error(msg): def get_temporary_credentials(session, key_names=BOTO_CRED_KEYS, role_arn=None): - sts_client = session.client("sts") + sts_client = session.client( + "sts", + endpoint_url=get_service_endpoint("sts", session.region_name), + region_name=session.region_name, + ) if role_arn: session_name = "CloudFormationContractTest-{:%Y%m%d%H%M%S}".format( datetime.now() @@ -62,3 +68,11 @@ def get_temporary_credentials(session, key_names=BOTO_CRED_KEYS, role_arn=None): temp = response["Credentials"] creds = (temp["AccessKeyId"], temp["SecretAccessKey"], temp["SessionToken"]) return dict(zip(key_names, creds)) + + +def get_service_endpoint(service, region): + loader = botocore.loaders.create_loader() + data = loader.load_data("endpoints") + resolver = botocore.regions.EndpointResolver(data) + endpoint_data = resolver.construct_endpoint(service, region) + return "https://" + endpoint_data["hostname"] diff --git a/tests/test_boto_helpers.py b/tests/test_boto_helpers.py index c5bac329..aaf4a9ab 100644 --- a/tests/test_boto_helpers.py +++ b/tests/test_boto_helpers.py @@ -57,11 +57,16 @@ def test_get_temporary_credentials_has_token(): frozen.access_key = object() frozen.secret_key = object() frozen.token = object() + session.region_name = "us-east-2" creds = get_temporary_credentials(session) session.get_credentials.assert_called_once_with() - session.client.assert_called_once_with("sts") + session.client.assert_called_once_with( + "sts", + endpoint_url="https://sts.us-east-2.amazonaws.com", + region_name="us-east-2", + ) assert len(creds) == 3 assert tuple(creds.keys()) == BOTO_CRED_KEYS @@ -85,11 +90,16 @@ def test_get_temporary_credentials_needs_token(): "SessionToken": token, } } + session.region_name = "us-east-2" creds = get_temporary_credentials(session, LOWER_CAMEL_CRED_KEYS) session.get_credentials.assert_called_once_with() - session.client.assert_called_once_with("sts") + session.client.assert_called_once_with( + "sts", + endpoint_url="https://sts.us-east-2.amazonaws.com", + region_name="us-east-2", + ) client.get_session_token.assert_called_once_with() assert len(creds) == 3 @@ -113,12 +123,17 @@ def test_get_temporary_credentials_invalid_credentials(): }, "GetSessionToken", ) + session.region_name = "us-east-2" with pytest.raises(DownstreamError): get_temporary_credentials(session) session.get_credentials.assert_called_once_with() - session.client.assert_called_once_with("sts") + session.client.assert_called_once_with( + "sts", + endpoint_url="https://sts.us-east-2.amazonaws.com", + region_name="us-east-2", + ) client.get_session_token.assert_called_once_with() @@ -136,11 +151,16 @@ def test_get_temporary_credentials_assume_role_fails(): }, "GetSessionToken", ) + session.region_name = "us-east-2" with pytest.raises(DownstreamError): get_temporary_credentials(session, role_arn=EXPECTED_ROLE) - session.client.assert_called_once_with("sts") + session.client.assert_called_once_with( + "sts", + endpoint_url="https://sts.us-east-2.amazonaws.com", + region_name="us-east-2", + ) client.assume_role.assert_called_once_with( RoleArn=EXPECTED_ROLE, RoleSessionName=ANY ) @@ -161,10 +181,15 @@ def test_get_temporary_credentials_assume_role(): "SessionToken": token, } } + session.region_name = "cn-north-1" creds = get_temporary_credentials(session, LOWER_CAMEL_CRED_KEYS, EXPECTED_ROLE) - session.client.assert_called_once_with("sts") + session.client.assert_called_once_with( + "sts", + endpoint_url="https://sts.cn-north-1.amazonaws.com.cn", + region_name="cn-north-1", + ) client.assume_role.assert_called_once_with( RoleArn=EXPECTED_ROLE, RoleSessionName=ANY )