Skip to content

Commit

Permalink
Use regional sts endpoints (#461)
Browse files Browse the repository at this point in the history
  • Loading branch information
nina-ctrlv authored Jun 16, 2020
1 parent 5df4efe commit a250245
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 6 deletions.
16 changes: 15 additions & 1 deletion src/rpdk/core/boto_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
from datetime import datetime

import botocore.loaders
import botocore.regions
from boto3 import Session as Boto3Session
from botocore.exceptions import ClientError

Expand Down Expand Up @@ -31,7 +33,11 @@ def _known_error(msg):


def get_temporary_credentials(session, key_names=BOTO_CRED_KEYS, role_arn=None):
sts_client = session.client("sts")
sts_client = session.client(
"sts",
endpoint_url=get_service_endpoint("sts", session.region_name),
region_name=session.region_name,
)
if role_arn:
session_name = "CloudFormationContractTest-{:%Y%m%d%H%M%S}".format(
datetime.now()
Expand Down Expand Up @@ -62,3 +68,11 @@ def get_temporary_credentials(session, key_names=BOTO_CRED_KEYS, role_arn=None):
temp = response["Credentials"]
creds = (temp["AccessKeyId"], temp["SecretAccessKey"], temp["SessionToken"])
return dict(zip(key_names, creds))


def get_service_endpoint(service, region):
loader = botocore.loaders.create_loader()
data = loader.load_data("endpoints")
resolver = botocore.regions.EndpointResolver(data)
endpoint_data = resolver.construct_endpoint(service, region)
return "https://" + endpoint_data["hostname"]
35 changes: 30 additions & 5 deletions tests/test_boto_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,16 @@ def test_get_temporary_credentials_has_token():
frozen.access_key = object()
frozen.secret_key = object()
frozen.token = object()
session.region_name = "us-east-2"

creds = get_temporary_credentials(session)

session.get_credentials.assert_called_once_with()
session.client.assert_called_once_with("sts")
session.client.assert_called_once_with(
"sts",
endpoint_url="https://sts.us-east-2.amazonaws.com",
region_name="us-east-2",
)

assert len(creds) == 3
assert tuple(creds.keys()) == BOTO_CRED_KEYS
Expand All @@ -85,11 +90,16 @@ def test_get_temporary_credentials_needs_token():
"SessionToken": token,
}
}
session.region_name = "us-east-2"

creds = get_temporary_credentials(session, LOWER_CAMEL_CRED_KEYS)

session.get_credentials.assert_called_once_with()
session.client.assert_called_once_with("sts")
session.client.assert_called_once_with(
"sts",
endpoint_url="https://sts.us-east-2.amazonaws.com",
region_name="us-east-2",
)
client.get_session_token.assert_called_once_with()

assert len(creds) == 3
Expand All @@ -113,12 +123,17 @@ def test_get_temporary_credentials_invalid_credentials():
},
"GetSessionToken",
)
session.region_name = "us-east-2"

with pytest.raises(DownstreamError):
get_temporary_credentials(session)

session.get_credentials.assert_called_once_with()
session.client.assert_called_once_with("sts")
session.client.assert_called_once_with(
"sts",
endpoint_url="https://sts.us-east-2.amazonaws.com",
region_name="us-east-2",
)
client.get_session_token.assert_called_once_with()


Expand All @@ -136,11 +151,16 @@ def test_get_temporary_credentials_assume_role_fails():
},
"GetSessionToken",
)
session.region_name = "us-east-2"

with pytest.raises(DownstreamError):
get_temporary_credentials(session, role_arn=EXPECTED_ROLE)

session.client.assert_called_once_with("sts")
session.client.assert_called_once_with(
"sts",
endpoint_url="https://sts.us-east-2.amazonaws.com",
region_name="us-east-2",
)
client.assume_role.assert_called_once_with(
RoleArn=EXPECTED_ROLE, RoleSessionName=ANY
)
Expand All @@ -161,10 +181,15 @@ def test_get_temporary_credentials_assume_role():
"SessionToken": token,
}
}
session.region_name = "cn-north-1"

creds = get_temporary_credentials(session, LOWER_CAMEL_CRED_KEYS, EXPECTED_ROLE)

session.client.assert_called_once_with("sts")
session.client.assert_called_once_with(
"sts",
endpoint_url="https://sts.cn-north-1.amazonaws.com.cn",
region_name="cn-north-1",
)
client.assume_role.assert_called_once_with(
RoleArn=EXPECTED_ROLE, RoleSessionName=ANY
)
Expand Down

0 comments on commit a250245

Please sign in to comment.