Skip to content

Commit

Permalink
Adding headers in STS Calls for Confused Deputy
Browse files Browse the repository at this point in the history
  • Loading branch information
saieshwarm committed Mar 7, 2024
1 parent f0e3769 commit 9197399
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 46 deletions.
4 changes: 3 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ disable=
duplicate-code, # finds dupes between tests and plugins
too-few-public-methods, # triggers when inheriting
ungrouped-imports, # clashes with isort
W0613 # Unused argument 'kwargs'

[BASIC]

Expand All @@ -23,4 +24,5 @@ indent-string=' '
max-line-length=160

[DESIGN]
max-locals=16
max-locals=17
max-args=6
22 changes: 20 additions & 2 deletions src/rpdk/core/boto_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,35 @@ def _known_error(msg):
return session


def get_temporary_credentials(session, key_names=BOTO_CRED_KEYS, role_arn=None):
def get_temporary_credentials(
session, key_names=BOTO_CRED_KEYS, role_arn=None, headers=None
):
sts_client = session.client(
"sts",
endpoint_url=get_service_endpoint("sts", session.region_name),
region_name=session.region_name,
)
check_keys = {"account_id", "source_arn"}
if (
headers
and check_keys.issubset(headers.keys())
and headers["account_id"]
and headers["source_arn"]
):
# Inject headers through the event system.
def inject_confused_deputy_headers(params, **kwargs):
params["headers"]["x-amz-source-account"] = headers["account_id"]
params["headers"]["x-amz-source-arn"] = headers["source_arn"]

sts_client.meta.events.register("before-call", inject_confused_deputy_headers)
LOG.info(headers)
if role_arn:
session_name = f"CloudFormationContractTest-{datetime.now():%Y%m%d%H%M%S}"
try:
response = sts_client.assume_role(
RoleArn=role_arn, RoleSessionName=session_name, DurationSeconds=900
RoleArn=role_arn,
RoleSessionName=session_name,
DurationSeconds=900,
)
except ClientError:
# pylint: disable=W1201
Expand Down
10 changes: 7 additions & 3 deletions src/rpdk/core/contract/hook_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
type_name=None,
log_group_name=None,
log_role_arn=None,
headers=None,
docker_image=None,
typeconfig=None,
executable_entrypoint=None,
Expand All @@ -69,9 +70,12 @@ def __init__(
self._log_group_name = log_group_name
self._log_role_arn = log_role_arn
self.region = region
self._headers = headers
self.account = get_account(
self._session,
get_temporary_credentials(self._session, LOWER_CAMEL_CRED_KEYS, role_arn),
get_temporary_credentials(
self._session, LOWER_CAMEL_CRED_KEYS, role_arn, headers
),
)
self._function_name = function_name
if endpoint.startswith("http://"):
Expand Down Expand Up @@ -396,11 +400,11 @@ def _make_payload(
self.account,
invocation_point,
get_temporary_credentials(
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn, self._headers
),
self._log_group_name,
get_temporary_credentials(
self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn
self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn, self._headers
),
self.generate_token(),
target_model,
Expand Down
12 changes: 8 additions & 4 deletions src/rpdk/core/contract/resource_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def __init__(
type_name=None,
log_group_name=None,
log_role_arn=None,
headers=None,
docker_image=None,
typeconfig=None,
executable_entrypoint=None,
Expand All @@ -182,9 +183,12 @@ def __init__(
self._log_group_name = log_group_name
self._log_role_arn = log_role_arn
self.region = region
self._headers = headers
self.account = get_account(
self._session,
get_temporary_credentials(self._session, LOWER_CAMEL_CRED_KEYS, role_arn),
get_temporary_credentials(
self._session, LOWER_CAMEL_CRED_KEYS, role_arn, headers
),
)
self._function_name = function_name
if endpoint.startswith("http://"):
Expand Down Expand Up @@ -674,12 +678,12 @@ def _make_payload(
self.account,
action,
get_temporary_credentials(
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn, self._headers
),
self._type_name,
self._log_group_name,
get_temporary_credentials(
self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn
self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn, self._headers
),
self.generate_token(),
type_configuration=type_configuration,
Expand Down Expand Up @@ -794,7 +798,7 @@ def call(self, action, current_model, previous_model=None, **kwargs):
request["callbackContext"] = response.get("callbackContext")
# refresh credential for every handler invocation
request["requestData"]["callerCredentials"] = get_temporary_credentials(
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn, self._headers
)

response = self._call(request)
Expand Down
55 changes: 45 additions & 10 deletions src/rpdk/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,13 @@ def temporary_ini_file():
yield str(path)


def get_cloudformation_exports(region_name, endpoint_url, role_arn, profile_name):
def get_cloudformation_exports(
region_name, endpoint_url, role_arn, profile_name, headers
):
session = create_sdk_session(region_name, profile_name)
temp_credentials = get_temporary_credentials(session, role_arn=role_arn)
temp_credentials = get_temporary_credentials(
session, role_arn=role_arn, headers=headers
)
cfn_client = session.client(
"cloudformation", endpoint_url=endpoint_url, **temp_credentials
)
Expand Down Expand Up @@ -132,13 +136,13 @@ def __retrieve_args(match):


def render_template(
overrides_string, region_name, endpoint_url, role_arn, profile_name
overrides_string, region_name, endpoint_url, role_arn, profile_name, headers
):
regex = r"{{([-A-Za-z0-9:\s]+?)}}"
variables = set(str(match).strip() for match in re.findall(regex, overrides_string))
if variables:
exports = get_cloudformation_exports(
region_name, endpoint_url, role_arn, profile_name
region_name, endpoint_url, role_arn, profile_name, headers
)
invalid_exports = variables - exports.keys()
if len(invalid_exports) > 0:
Expand Down Expand Up @@ -166,15 +170,20 @@ def filter_overrides(overrides, project):
return overrides


def get_overrides(root, region_name, endpoint_url, role_arn, profile_name):
def get_overrides(root, region_name, endpoint_url, role_arn, profile_name, headers):
if not root:
return empty_override()

path = root / "overrides.json"
try:
with path.open("r", encoding="utf-8") as f:
overrides_raw = render_template(
f.read(), region_name, endpoint_url, role_arn, profile_name
f.read(),
region_name,
endpoint_url,
role_arn,
profile_name,
headers=headers,
)
except FileNotFoundError:
LOG.debug("Override file '%s' not found. No overrides will be applied", path)
Expand Down Expand Up @@ -203,15 +212,22 @@ def get_overrides(root, region_name, endpoint_url, role_arn, profile_name):

# pylint: disable=R0914
# flake8: noqa: C901
def get_hook_overrides(root, region_name, endpoint_url, role_arn, profile_name):
def get_hook_overrides(
root, region_name, endpoint_url, role_arn, profile_name, headers
):
if not root:
return empty_hook_override()

path = root / "overrides.json"
try:
with path.open("r", encoding="utf-8") as f:
overrides_raw = render_template(
f.read(), region_name, endpoint_url, role_arn, profile_name
f.read(),
region_name,
endpoint_url,
role_arn,
profile_name,
headers=headers,
)
except FileNotFoundError:
LOG.debug("Override file '%s' not found. No overrides will be applied", path)
Expand Down Expand Up @@ -258,7 +274,7 @@ def get_hook_overrides(root, region_name, endpoint_url, role_arn, profile_name):


# pylint: disable=R0914,too-many-arguments
def get_inputs(root, region_name, endpoint_url, value, role_arn, profile_name):
def get_inputs(root, region_name, endpoint_url, value, role_arn, profile_name, headers):
inputs = {}
if not root:
return None
Expand All @@ -280,7 +296,12 @@ def get_inputs(root, region_name, endpoint_url, value, role_arn, profile_name):
file_path = path / file
with file_path.open("r", encoding="utf-8") as f:
overrides_raw = render_template(
f.read(), region_name, endpoint_url, role_arn, profile_name
f.read(),
region_name,
endpoint_url,
role_arn,
profile_name,
headers=headers,
)
overrides = {}
for pointer, obj in overrides_raw.items():
Expand Down Expand Up @@ -355,6 +376,7 @@ def get_contract_plugin_client(args, project, overrides, inputs):
project.type_name,
args.log_group_name,
args.log_role_arn,
headers={"account_id": args.source_account, "source_arn": args.source_arn},
executable_entrypoint=project.executable_entrypoint,
docker_image=args.docker_image,
typeconfig=args.typeconfig,
Expand All @@ -378,6 +400,7 @@ def get_contract_plugin_client(args, project, overrides, inputs):
project.type_name,
args.log_group_name,
args.log_role_arn,
headers={"account_id": args.source_account, "source_arn": args.source_arn},
typeconfig=args.typeconfig,
executable_entrypoint=project.executable_entrypoint,
docker_image=args.docker_image,
Expand All @@ -402,6 +425,7 @@ def test(args):
args.cloudformation_endpoint_url,
args.role_arn,
args.profile,
headers={"account_id": args.source_account, "source_arn": args.source_arn},
)
else:
overrides = get_overrides(
Expand All @@ -410,6 +434,7 @@ def test(args):
args.cloudformation_endpoint_url,
args.role_arn,
args.profile,
headers={"account_id": args.source_account, "source_arn": args.source_arn},
)
filter_overrides(overrides, project)

Expand All @@ -422,6 +447,7 @@ def test(args):
index,
args.role_arn,
args.profile,
headers={"account_id": args.source_account, "source_arn": args.source_arn},
)
if not inputs:
break
Expand Down Expand Up @@ -509,6 +535,15 @@ def setup_subparser(subparsers, parents):
" '~/.cfn-cli/typeConfiguration.json.'"
),
)
parser.add_argument(
"--source-account",
help="Source Account key used for Assume Role to Run Contract Tests",
)

parser.add_argument(
"--source-arn",
help="Source Type Version Arn key used for Assume Role to Run Contract Tests",
)


def _sam_arguments(parser):
Expand Down
6 changes: 3 additions & 3 deletions tests/contract/test_hook_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def hook_client():
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})
assert client._function_name == DEFAULT_FUNCTION
assert client._schema == SCHEMA_
Expand Down Expand Up @@ -179,7 +179,7 @@ def hook_client_inputs():
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})
assert client._function_name == DEFAULT_FUNCTION
assert client._schema == SCHEMA_
Expand Down Expand Up @@ -215,7 +215,7 @@ def test_init_sam_cli_client():
mock_sesh.client.assert_called_once_with(
"lambda", endpoint_url=DEFAULT_ENDPOINT, use_ssl=False, verify=False, config=ANY
)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})
assert client.account == ACCOUNT

Expand Down
14 changes: 7 additions & 7 deletions tests/contract/test_resource_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def resource_client():
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})
assert client._function_name == DEFAULT_FUNCTION
assert client._schema == EMPTY_SCHEMA
Expand Down Expand Up @@ -214,7 +214,7 @@ def resource_client_no_handler():
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})
assert client._function_name == DEFAULT_FUNCTION
assert client._schema == {}
Expand Down Expand Up @@ -254,7 +254,7 @@ def resource_client_inputs():
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})

assert client._function_name == DEFAULT_FUNCTION
Expand Down Expand Up @@ -299,7 +299,7 @@ def resource_client_inputs_schema(request):
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})

assert client._function_name == DEFAULT_FUNCTION
Expand Down Expand Up @@ -344,7 +344,7 @@ def resource_client_inputs_composite_key():
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})

assert client._function_name == DEFAULT_FUNCTION
Expand Down Expand Up @@ -384,7 +384,7 @@ def resource_client_inputs_property_transform():
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})
assert client._function_name == DEFAULT_FUNCTION
assert client._schema == SCHEMA_WITH_PROPERTY_TRANSFORM
Expand Down Expand Up @@ -693,7 +693,7 @@ def test_init_sam_cli_client():
mock_sesh.client.assert_called_once_with(
"lambda", endpoint_url=DEFAULT_ENDPOINT, use_ssl=False, verify=False, config=ANY
)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})
assert client.account == ACCOUNT

Expand Down
Loading

0 comments on commit 9197399

Please sign in to comment.