diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py index 161aed5bb5988..b76eb95def1ed 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py @@ -1,7 +1,12 @@ +import logging +import os from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union +from enum import Enum +from http import HTTPStatus +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union import boto3 +import requests from boto3.session import Session from botocore.config import DEFAULT_TIMEOUT, Config from botocore.utils import fix_s3_host @@ -14,6 +19,8 @@ ) from datahub.configuration.source_common import EnvConfigMixin +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from mypy_boto3_dynamodb import DynamoDBClient from mypy_boto3_glue import GlueClient @@ -22,6 +29,26 @@ from mypy_boto3_sts import STSClient +class AwsEnvironment(Enum): + EC2 = "EC2" + ECS = "ECS" + EKS = "EKS" + LAMBDA = "LAMBDA" + APP_RUNNER = "APP_RUNNER" + BEANSTALK = "ELASTIC_BEANSTALK" + CLOUD_FORMATION = "CLOUD_FORMATION" + UNKNOWN = "UNKNOWN" + + +class AwsServicePrincipal(Enum): + LAMBDA = "lambda.amazonaws.com" + EKS = "eks.amazonaws.com" + APP_RUNNER = "apprunner.amazonaws.com" + ECS = "ecs.amazonaws.com" + ELASTIC_BEANSTALK = "elasticbeanstalk.amazonaws.com" + EC2 = "ec2.amazonaws.com" + + class AwsAssumeRoleConfig(PermissiveConfigModel): # Using the PermissiveConfigModel to allow the user to pass additional arguments. @@ -34,6 +61,163 @@ class AwsAssumeRoleConfig(PermissiveConfigModel): ) +def get_instance_metadata_token() -> Optional[str]: + """Get IMDSv2 token""" + try: + response = requests.put( + "http://169.254.169.254/latest/api/token", + headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"}, + timeout=1, + ) + if response.status_code == HTTPStatus.OK: + return response.text + except requests.exceptions.RequestException: + logger.debug("Failed to get IMDSv2 token") + return None + + +def is_running_on_ec2() -> bool: + """Check if code is running on EC2 using IMDSv2""" + token = get_instance_metadata_token() + if not token: + return False + + try: + response = requests.get( + "http://169.254.169.254/latest/meta-data/instance-id", + headers={"X-aws-ec2-metadata-token": token}, + timeout=1, + ) + return response.status_code == HTTPStatus.OK + except requests.exceptions.RequestException: + return False + + +def detect_aws_environment() -> AwsEnvironment: + """ + Detect the AWS environment we're running in. + Order matters as some environments may have multiple indicators. + """ + # Check Lambda first as it's most specific + if os.getenv("AWS_LAMBDA_FUNCTION_NAME"): + if os.getenv("AWS_EXECUTION_ENV", "").startswith("CloudFormation"): + return AwsEnvironment.CLOUD_FORMATION + return AwsEnvironment.LAMBDA + + # Check EKS (IRSA) + if os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE") and os.getenv("AWS_ROLE_ARN"): + return AwsEnvironment.EKS + + # Check App Runner + if os.getenv("AWS_APP_RUNNER_SERVICE_ID"): + return AwsEnvironment.APP_RUNNER + + # Check ECS + if os.getenv("ECS_CONTAINER_METADATA_URI_V4") or os.getenv( + "ECS_CONTAINER_METADATA_URI" + ): + return AwsEnvironment.ECS + + # Check Elastic Beanstalk + if os.getenv("ELASTIC_BEANSTALK_ENVIRONMENT_NAME"): + return AwsEnvironment.BEANSTALK + + if is_running_on_ec2(): + return AwsEnvironment.EC2 + + return AwsEnvironment.UNKNOWN + + +def get_instance_role_arn() -> Optional[str]: + """Get role ARN from EC2 instance metadata using IMDSv2""" + token = get_instance_metadata_token() + if not token: + return None + + try: + response = requests.get( + "http://169.254.169.254/latest/meta-data/iam/security-credentials/", + headers={"X-aws-ec2-metadata-token": token}, + timeout=1, + ) + if response.status_code == 200: + role_name = response.text.strip() + if role_name: + sts = boto3.client("sts") + identity = sts.get_caller_identity() + return identity.get("Arn") + except Exception as e: + logger.debug(f"Failed to get instance role ARN: {e}") + return None + + +def get_lambda_role_arn() -> Optional[str]: + """Get the Lambda function's role ARN""" + try: + function_name = os.getenv("AWS_LAMBDA_FUNCTION_NAME") + if not function_name: + return None + + lambda_client = boto3.client("lambda") + function_config = lambda_client.get_function_configuration( + FunctionName=function_name + ) + return function_config.get("Role") + except Exception as e: + logger.debug(f"Failed to get Lambda role ARN: {e}") + return None + + +def get_current_identity() -> Tuple[Optional[str], Optional[str]]: + """ + Get the current role ARN and source type based on the runtime environment. + Returns (role_arn, credential_source) + """ + env = detect_aws_environment() + + if env == AwsEnvironment.LAMBDA: + role_arn = get_lambda_role_arn() + return role_arn, AwsServicePrincipal.LAMBDA.value + + elif env == AwsEnvironment.EKS: + role_arn = os.getenv("AWS_ROLE_ARN") + return role_arn, AwsServicePrincipal.EKS.value + + elif env == AwsEnvironment.APP_RUNNER: + try: + sts = boto3.client("sts") + identity = sts.get_caller_identity() + return identity.get("Arn"), AwsServicePrincipal.APP_RUNNER.value + except Exception as e: + logger.debug(f"Failed to get App Runner role: {e}") + + elif env == AwsEnvironment.ECS: + try: + metadata_uri = os.getenv("ECS_CONTAINER_METADATA_URI_V4") or os.getenv( + "ECS_CONTAINER_METADATA_URI" + ) + if metadata_uri: + response = requests.get(f"{metadata_uri}/task", timeout=1) + if response.status_code == HTTPStatus.OK: + task_metadata = response.json() + if "TaskARN" in task_metadata: + return ( + task_metadata.get("TaskARN"), + AwsServicePrincipal.ECS.value, + ) + except Exception as e: + logger.debug(f"Failed to get ECS task role: {e}") + + elif env == AwsEnvironment.BEANSTALK: + # Beanstalk uses EC2 instance metadata + return get_instance_role_arn(), AwsServicePrincipal.ELASTIC_BEANSTALK.value + + elif env == AwsEnvironment.EC2: + return get_instance_role_arn(), AwsServicePrincipal.EC2.value + + return None, None + + def assume_role( role: AwsAssumeRoleConfig, aws_region: Optional[str], @@ -95,7 +279,7 @@ class AwsConnectionConfig(ConfigModel): ) aws_profile: Optional[str] = Field( default=None, - description="Named AWS profile to use. Only used if access key / secret are unset. If not set the default will be used", + description="The [named profile](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-profiles.html) to use from AWS credentials. Falls back to default profile if not specified and no access keys provided. Profiles are configured in ~/.aws/credentials or ~/.aws/config.", ) aws_region: Optional[str] = Field(None, description="AWS region code.") @@ -145,6 +329,7 @@ def _normalized_aws_roles(self) -> List[AwsAssumeRoleConfig]: def get_session(self) -> Session: if self.aws_access_key_id and self.aws_secret_access_key: + # Explicit credentials take precedence session = Session( aws_access_key_id=self.aws_access_key_id, aws_secret_access_key=self.aws_secret_access_key, @@ -152,38 +337,57 @@ def get_session(self) -> Session: region_name=self.aws_region, ) elif self.aws_profile: + # Named profile is second priority session = Session( region_name=self.aws_region, profile_name=self.aws_profile ) else: - # Use boto3's credential autodetection. + # Use boto3's credential autodetection session = Session(region_name=self.aws_region) - if self._normalized_aws_roles(): - # Use existing session credentials to start the chain of role assumption. - current_credentials = session.get_credentials() - credentials = { - "AccessKeyId": current_credentials.access_key, - "SecretAccessKey": current_credentials.secret_key, - "SessionToken": current_credentials.token, - } - - for role in self._normalized_aws_roles(): - if self._should_refresh_credentials(): - credentials = assume_role( - role, - self.aws_region, - credentials=credentials, + target_roles = self._normalized_aws_roles() + if target_roles: + current_role_arn, credential_source = get_current_identity() + + # Only assume role if: + # 1. We're not in a known AWS environment with a role, or + # 2. We need to assume a different role than our current one + should_assume_role = current_role_arn is None or any( + role.RoleArn != current_role_arn for role in target_roles + ) + + if should_assume_role: + env = detect_aws_environment() + logger.debug(f"Assuming role(s) from {env.value} environment") + + current_credentials = session.get_credentials() + if current_credentials is None: + raise ValueError("No credentials available for role assumption") + + credentials = { + "AccessKeyId": current_credentials.access_key, + "SecretAccessKey": current_credentials.secret_key, + "SessionToken": current_credentials.token, + } + + for role in target_roles: + if self._should_refresh_credentials(): + credentials = assume_role( + role=role, + aws_region=self.aws_region, + credentials=credentials, + ) + if isinstance(credentials["Expiration"], datetime): + self._credentials_expiration = credentials["Expiration"] + + session = Session( + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + region_name=self.aws_region, ) - if isinstance(credentials["Expiration"], datetime): - self._credentials_expiration = credentials["Expiration"] - - session = Session( - aws_access_key_id=credentials["AccessKeyId"], - aws_secret_access_key=credentials["SecretAccessKey"], - aws_session_token=credentials["SessionToken"], - region_name=self.aws_region, - ) + else: + logger.debug(f"Using existing role from {credential_source}") return session diff --git a/metadata-ingestion/tests/unit/test_aws_common.py b/metadata-ingestion/tests/unit/test_aws_common.py new file mode 100644 index 0000000000000..9291fb91134b1 --- /dev/null +++ b/metadata-ingestion/tests/unit/test_aws_common.py @@ -0,0 +1,328 @@ +import json +import os +from unittest.mock import MagicMock, patch + +import boto3 +import pytest +from moto import mock_iam, mock_lambda, mock_sts + +from datahub.ingestion.source.aws.aws_common import ( + AwsConnectionConfig, + AwsEnvironment, + detect_aws_environment, + get_current_identity, + get_instance_metadata_token, + get_instance_role_arn, + is_running_on_ec2, +) + + +@pytest.fixture +def mock_aws_config(): + return AwsConnectionConfig( + aws_access_key_id="test-key", + aws_secret_access_key="test-secret", + aws_region="us-east-1", + ) + + +class TestAwsCommon: + def test_environment_detection_no_environment(self): + """Test environment detection when no AWS environment is present""" + with patch.dict(os.environ, {}, clear=True): + assert detect_aws_environment() == AwsEnvironment.UNKNOWN + + def test_environment_detection_lambda(self): + """Test Lambda environment detection""" + with patch.dict(os.environ, {"AWS_LAMBDA_FUNCTION_NAME": "test-function"}): + assert detect_aws_environment() == AwsEnvironment.LAMBDA + + def test_environment_detection_lambda_cloudformation(self): + """Test CloudFormation Lambda environment detection""" + with patch.dict( + os.environ, + { + "AWS_LAMBDA_FUNCTION_NAME": "test-function", + "AWS_EXECUTION_ENV": "CloudFormation.xxx", + }, + ): + assert detect_aws_environment() == AwsEnvironment.CLOUD_FORMATION + + def test_environment_detection_eks(self): + """Test EKS environment detection""" + with patch.dict( + os.environ, + { + "AWS_WEB_IDENTITY_TOKEN_FILE": "/var/run/secrets/token", + "AWS_ROLE_ARN": "arn:aws:iam::123456789012:role/test-role", + }, + ): + assert detect_aws_environment() == AwsEnvironment.EKS + + def test_environment_detection_app_runner(self): + """Test App Runner environment detection""" + with patch.dict(os.environ, {"AWS_APP_RUNNER_SERVICE_ID": "service-id"}): + assert detect_aws_environment() == AwsEnvironment.APP_RUNNER + + def test_environment_detection_ecs(self): + """Test ECS environment detection""" + with patch.dict( + os.environ, {"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2/v4"} + ): + assert detect_aws_environment() == AwsEnvironment.ECS + + def test_environment_detection_beanstalk(self): + """Test Elastic Beanstalk environment detection""" + with patch.dict(os.environ, {"ELASTIC_BEANSTALK_ENVIRONMENT_NAME": "my-env"}): + assert detect_aws_environment() == AwsEnvironment.BEANSTALK + + @patch("requests.put") + def test_ec2_metadata_token(self, mock_put): + """Test EC2 metadata token retrieval""" + mock_put.return_value.status_code = 200 + mock_put.return_value.text = "token123" + + token = get_instance_metadata_token() + assert token == "token123" + + mock_put.assert_called_once_with( + "http://169.254.169.254/latest/api/token", + headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"}, + timeout=1, + ) + + @patch("requests.put") + def test_ec2_metadata_token_failure(self, mock_put): + """Test EC2 metadata token failure case""" + mock_put.return_value.status_code = 404 + + token = get_instance_metadata_token() + assert token is None + + @patch("requests.get") + @patch("requests.put") + def test_is_running_on_ec2(self, mock_put, mock_get): + """Test EC2 instance detection with IMDSv2""" + mock_put.return_value.status_code = 200 + mock_put.return_value.text = "token123" + mock_get.return_value.status_code = 200 + + assert is_running_on_ec2() is True + + mock_put.assert_called_once_with( + "http://169.254.169.254/latest/api/token", + headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"}, + timeout=1, + ) + mock_get.assert_called_once_with( + "http://169.254.169.254/latest/meta-data/instance-id", + headers={"X-aws-ec2-metadata-token": "token123"}, + timeout=1, + ) + + @patch("requests.get") + @patch("requests.put") + def test_is_running_on_ec2_failure(self, mock_put, mock_get): + """Test EC2 instance detection failure""" + mock_put.return_value.status_code = 404 + assert is_running_on_ec2() is False + + mock_put.return_value.status_code = 200 + mock_put.return_value.text = "token123" + mock_get.return_value.status_code = 404 + assert is_running_on_ec2() is False + + @mock_sts + @mock_lambda + @mock_iam + def test_get_current_identity_lambda(self): + """Test getting identity in Lambda environment""" + with patch.dict( + os.environ, + { + "AWS_LAMBDA_FUNCTION_NAME": "test-function", + "AWS_DEFAULT_REGION": "us-east-1", + }, + ): + # Create IAM role first with proper trust policy + iam_client = boto3.client("iam", region_name="us-east-1") + trust_policy = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": {"Service": "lambda.amazonaws.com"}, + "Action": "sts:AssumeRole", + } + ], + } + iam_client.create_role( + RoleName="test-role", AssumeRolePolicyDocument=json.dumps(trust_policy) + ) + + lambda_client = boto3.client("lambda", region_name="us-east-1") + lambda_client.create_function( + FunctionName="test-function", + Runtime="python3.8", + Role="arn:aws:iam::123456789012:role/test-role", + Handler="index.handler", + Code={"ZipFile": b"def handler(event, context): pass"}, + ) + + role_arn, source = get_current_identity() + assert source == "lambda.amazonaws.com" + assert role_arn == "arn:aws:iam::123456789012:role/test-role" + + @patch("requests.get") + @patch("requests.put") + @mock_sts + def test_get_instance_role_arn_success(self, mock_put, mock_get): + """Test getting EC2 instance role ARN""" + mock_put.return_value.status_code = 200 + mock_put.return_value.text = "token123" + mock_get.return_value.status_code = 200 + mock_get.return_value.text = "test-role" + + with patch("boto3.client") as mock_boto: + mock_sts = MagicMock() + mock_sts.get_caller_identity.return_value = { + "Arn": "arn:aws:sts::123456789012:assumed-role/test-role/instance" + } + mock_boto.return_value = mock_sts + + role_arn = get_instance_role_arn() + assert ( + role_arn == "arn:aws:sts::123456789012:assumed-role/test-role/instance" + ) + + @mock_sts + def test_aws_connection_config_basic(self, mock_aws_config): + """Test basic AWS connection configuration""" + session = mock_aws_config.get_session() + creds = session.get_credentials() + assert creds.access_key == "test-key" + assert creds.secret_key == "test-secret" + + @mock_sts + def test_aws_connection_config_with_session_token(self): + """Test AWS connection with session token""" + config = AwsConnectionConfig( + aws_access_key_id="test-key", + aws_secret_access_key="test-secret", + aws_session_token="test-token", + aws_region="us-east-1", + ) + + session = config.get_session() + creds = session.get_credentials() + assert creds.token == "test-token" + + @mock_sts + def test_aws_connection_config_role_assumption(self): + """Test AWS connection with role assumption""" + config = AwsConnectionConfig( + aws_access_key_id="test-key", + aws_secret_access_key="test-secret", + aws_region="us-east-1", + aws_role="arn:aws:iam::123456789012:role/test-role", + ) + + with patch( + "datahub.ingestion.source.aws.aws_common.get_current_identity" + ) as mock_identity: + mock_identity.return_value = (None, None) + session = config.get_session() + creds = session.get_credentials() + assert creds is not None + + @mock_sts + def test_aws_connection_config_skip_role_assumption(self): + """Test AWS connection skipping role assumption when already in role""" + config = AwsConnectionConfig( + aws_region="us-east-1", + aws_role="arn:aws:iam::123456789012:role/current-role", + ) + + with patch( + "datahub.ingestion.source.aws.aws_common.get_current_identity" + ) as mock_identity: + mock_identity.return_value = ( + "arn:aws:iam::123456789012:role/current-role", + "ec2.amazonaws.com", + ) + session = config.get_session() + assert session is not None + + @mock_sts + def test_aws_connection_config_multiple_roles(self): + """Test AWS connection with multiple role assumption""" + config = AwsConnectionConfig( + aws_access_key_id="test-key", + aws_secret_access_key="test-secret", + aws_region="us-east-1", + aws_role=[ + "arn:aws:iam::123456789012:role/role1", + "arn:aws:iam::123456789012:role/role2", + ], + ) + + with patch( + "datahub.ingestion.source.aws.aws_common.get_current_identity" + ) as mock_identity: + mock_identity.return_value = (None, None) + session = config.get_session() + assert session is not None + + def test_aws_connection_config_validation_error(self): + """Test AWS connection validation""" + with patch.dict( + "os.environ", + { + "AWS_ACCESS_KEY_ID": "test-key", + # Deliberately missing AWS_SECRET_ACCESS_KEY + "AWS_DEFAULT_REGION": "us-east-1", + }, + clear=True, + ): + config = AwsConnectionConfig() # Let it pick up from environment + session = config.get_session() + with pytest.raises( + Exception, + match="Partial credentials found in env, missing: AWS_SECRET_ACCESS_KEY", + ): + session.get_credentials() + + @pytest.mark.parametrize( + "env_vars,expected_environment", + [ + ({}, AwsEnvironment.UNKNOWN), + ({"AWS_LAMBDA_FUNCTION_NAME": "test"}, AwsEnvironment.LAMBDA), + ( + { + "AWS_LAMBDA_FUNCTION_NAME": "test", + "AWS_EXECUTION_ENV": "CloudFormation", + }, + AwsEnvironment.CLOUD_FORMATION, + ), + ( + { + "AWS_WEB_IDENTITY_TOKEN_FILE": "/token", + "AWS_ROLE_ARN": "arn:aws:iam::123:role/test", + }, + AwsEnvironment.EKS, + ), + ({"AWS_APP_RUNNER_SERVICE_ID": "service-123"}, AwsEnvironment.APP_RUNNER), + ( + {"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2"}, + AwsEnvironment.ECS, + ), + ( + {"ELASTIC_BEANSTALK_ENVIRONMENT_NAME": "my-env"}, + AwsEnvironment.BEANSTALK, + ), + ], + ) + def test_environment_detection_parametrized(self, env_vars, expected_environment): + """Parametrized test for environment detection with different configurations""" + with patch.dict(os.environ, env_vars, clear=True): + assert detect_aws_environment() == expected_environment