Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ingestion/aws-common): improved instance profile support for ec2, ecs, eks, lambda, beanstalk, app runner and cft roles #12139

Merged
merged 15 commits into from
Dec 21, 2024
243 changes: 217 additions & 26 deletions metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union

import boto3
import requests
from boto3.session import Session
from botocore.config import DEFAULT_TIMEOUT, Config
from botocore.utils import fix_s3_host
Expand All @@ -14,6 +18,8 @@
)
from datahub.configuration.source_common import EnvConfigMixin

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from mypy_boto3_dynamodb import DynamoDBClient
from mypy_boto3_glue import GlueClient
Expand All @@ -22,6 +28,17 @@
from mypy_boto3_sts import STSClient


class AwsEnvironment(Enum):
    """Kinds of AWS runtime this module distinguishes between.

    Each member's value is a plain string; it is what gets written to the
    debug log when a role is assumed from a detected environment.
    """

    # Plain EC2 instance (detected via the instance metadata service).
    EC2 = "EC2"
    # Container platforms.
    ECS = "ECS"
    EKS = "EKS"
    # Lambda function runtime.
    LAMBDA = "LAMBDA"
    # Managed application platforms.
    APP_RUNNER = "APP_RUNNER"
    BEANSTALK = "ELASTIC_BEANSTALK"
    # Lambda runtime whose execution environment reports CloudFormation.
    CLOUD_FORMATION = "CLOUD_FORMATION"
    # No known indicator matched.
    UNKNOWN = "UNKNOWN"


class AwsAssumeRoleConfig(PermissiveConfigModel):
# Using the PermissiveConfigModel to allow the user to pass additional arguments.

Expand All @@ -34,6 +51,160 @@ class AwsAssumeRoleConfig(PermissiveConfigModel):
)


def get_instance_metadata_token() -> Optional[str]:
    """Request an IMDSv2 session token from the instance metadata endpoint.

    Returns:
        The token text when the endpoint answers with HTTP 200, otherwise
        ``None`` (non-200 answer, or the request itself failed).
    """
    token_endpoint = "http://169.254.169.254/latest/api/token"
    # Ask for a token that stays valid for 21600 seconds.
    ttl_headers = {"X-aws-ec2-metadata-token-ttl-seconds": "21600"}

    try:
        # Short timeout: off AWS the link-local address is simply unreachable.
        reply = requests.put(token_endpoint, headers=ttl_headers, timeout=1)
    except requests.exceptions.RequestException:
        logger.debug("Failed to get IMDSv2 token")
        return None

    return reply.text if reply.status_code == 200 else None


def is_running_on_ec2() -> bool:
    """Report whether the instance metadata service answers like an EC2 host.

    Returns:
        ``True`` only when an IMDSv2 token could be obtained and the
        ``instance-id`` metadata path responds with HTTP 200; ``False`` in
        every other case, including request failures.
    """
    imds_token = get_instance_metadata_token()
    if not imds_token:
        # No token means IMDSv2 is unreachable; nothing further to probe.
        return False

    instance_id_url = "http://169.254.169.254/latest/meta-data/instance-id"
    try:
        probe = requests.get(
            instance_id_url,
            headers={"X-aws-ec2-metadata-token": imds_token},
            timeout=1,
        )
    except requests.exceptions.RequestException:
        return False

    return probe.status_code == 200


def detect_aws_environment() -> AwsEnvironment:
    """Work out which AWS runtime the process is executing in.

    The checks run in a fixed order because one environment can expose
    several indicators; the first match wins. Environment variables are
    consulted first and the EC2 metadata probe is the last resort.

    Returns:
        The matching ``AwsEnvironment`` member, or ``AwsEnvironment.UNKNOWN``
        when no indicator is present.
    """
    environ = os.environ

    # Lambda is the most specific signal, so it is tested before the rest.
    if environ.get("AWS_LAMBDA_FUNCTION_NAME"):
        execution_env = environ.get("AWS_EXECUTION_ENV", "")
        if execution_env.startswith("CloudFormation"):
            return AwsEnvironment.CLOUD_FORMATION
        return AwsEnvironment.LAMBDA

    # EKS with IRSA: both the token file and the role ARN must be set.
    irsa_vars = ("AWS_WEB_IDENTITY_TOKEN_FILE", "AWS_ROLE_ARN")
    if all(environ.get(name) for name in irsa_vars):
        return AwsEnvironment.EKS

    # App Runner
    if environ.get("AWS_APP_RUNNER_SERVICE_ID"):
        return AwsEnvironment.APP_RUNNER

    # ECS: either version of the container metadata URI is enough.
    ecs_vars = ("ECS_CONTAINER_METADATA_URI_V4", "ECS_CONTAINER_METADATA_URI")
    if any(environ.get(name) for name in ecs_vars):
        return AwsEnvironment.ECS

    # Elastic Beanstalk
    if environ.get("ELASTIC_BEANSTALK_ENVIRONMENT_NAME"):
        return AwsEnvironment.BEANSTALK

    # Only now fall back to the network probe of the metadata service.
    return AwsEnvironment.EC2 if is_running_on_ec2() else AwsEnvironment.UNKNOWN


def _iam_role_arn_from_sts_arn(arn: Optional[str]) -> Optional[str]:
    """Map an STS assumed-role session ARN to the IAM role ARN behind it.

    ``arn:<partition>:sts::<account>:assumed-role/<role>/<session>`` becomes
    ``arn:<partition>:iam::<account>:role/<role>``. Any other ARN shape is
    returned unchanged, and a missing/empty ARN yields ``None``.

    NOTE: a session ARN does not carry the IAM role *path*, so a role created
    under a path cannot be reconstructed exactly; the result then simply will
    not equal the configured role ARN.
    """
    if not arn:
        return None

    fields = arn.split(":", 5)
    if len(fields) != 6 or fields[2] != "sts":
        return arn

    resource = fields[5].split("/")
    if len(resource) < 3 or resource[0] != "assumed-role":
        return arn

    partition, account_id, role_name = fields[1], fields[4], resource[1]
    return f"arn:{partition}:iam::{account_id}:role/{role_name}"


def get_instance_role_arn() -> Optional[str]:
    """Get the ARN of the IAM role attached to the current EC2 instance.

    The instance metadata service (IMDSv2) is asked whether a role is
    attached; if so, the caller identity reported by STS is used to build the
    ARN. STS reports instance-profile credentials as an ``assumed-role``
    session ARN, so that ARN is normalised to the IAM role ARN form, which is
    the form a configured role ARN is compared against.

    Returns:
        The role ARN, or ``None`` when no IMDSv2 token is available, no role
        is attached, or any request fails (failures are logged at debug
        level, never raised).
    """
    token = get_instance_metadata_token()
    if not token:
        return None

    try:
        response = requests.get(
            "http://169.254.169.254/latest/meta-data/iam/security-credentials/",
            headers={"X-aws-ec2-metadata-token": token},
            timeout=1,
        )
        # An empty listing means no role is attached to this instance.
        if response.status_code != 200 or not response.text.strip():
            return None

        identity = boto3.client("sts").get_caller_identity()
        return _iam_role_arn_from_sts_arn(identity.get("Arn"))
    except Exception as e:
        # Best-effort lookup: callers treat None as "identity unknown".
        logger.debug(f"Failed to get instance role ARN: {e}")
    return None


def get_lambda_role_arn() -> Optional[str]:
    """Look up the execution role ARN of the current Lambda function.

    The function name is read from ``AWS_LAMBDA_FUNCTION_NAME`` and its
    configuration is fetched through the Lambda API.

    Returns:
        The ``Role`` entry of the function configuration, or ``None`` when the
        variable is unset/empty or the lookup fails (failures are logged at
        debug level, never raised).
    """
    fn_name = os.environ.get("AWS_LAMBDA_FUNCTION_NAME")
    if not fn_name:
        return None

    try:
        config = boto3.client("lambda").get_function_configuration(
            FunctionName=fn_name
        )
        return config.get("Role")
    except Exception as err:
        logger.debug(f"Failed to get Lambda role ARN: {err}")
    return None


def get_current_identity() -> Tuple[Optional[str], Optional[str]]:
    """Resolve the identity the process currently runs as.

    The runtime is detected first and the identity is then read from the
    source appropriate to it. Lookups are best-effort: failures are logged at
    debug level and never raised.

    Returns:
        ``(arn, credential_source)``, where ``credential_source`` is the AWS
        service principal string for the detected environment. For ECS the
        first element is the ``TaskARN`` from the task metadata. The result is
        ``(None, None)`` when the environment is not handled here or the App
        Runner / ECS lookup fails.
    """
    runtime = detect_aws_environment()

    if runtime is AwsEnvironment.LAMBDA:
        return get_lambda_role_arn(), "lambda.amazonaws.com"

    if runtime is AwsEnvironment.EKS:
        # With IRSA the role ARN is injected straight into the environment.
        return os.getenv("AWS_ROLE_ARN"), "eks.amazonaws.com"

    if runtime is AwsEnvironment.APP_RUNNER:
        try:
            caller = boto3.client("sts").get_caller_identity()
            return caller.get("Arn"), "apprunner.amazonaws.com"
        except Exception as err:
            logger.debug(f"Failed to get App Runner role: {err}")
        return None, None

    if runtime is AwsEnvironment.ECS:
        try:
            # Prefer the v4 metadata endpoint, fall back to the older one.
            base_uri = os.getenv("ECS_CONTAINER_METADATA_URI_V4") or os.getenv(
                "ECS_CONTAINER_METADATA_URI"
            )
            if base_uri:
                reply = requests.get(f"{base_uri}/task", timeout=1)
                if reply.status_code == 200:
                    task_info = reply.json()
                    if "TaskARN" in task_info:
                        return task_info.get("TaskARN"), "ecs.amazonaws.com"
        except Exception as err:
            logger.debug(f"Failed to get ECS task role: {err}")
        return None, None

    if runtime is AwsEnvironment.BEANSTALK:
        # Beanstalk uses EC2 instance metadata
        return get_instance_role_arn(), "elasticbeanstalk.amazonaws.com"

    if runtime is AwsEnvironment.EC2:
        return get_instance_role_arn(), "ec2.amazonaws.com"

    return None, None


def assume_role(
role: AwsAssumeRoleConfig,
aws_region: Optional[str],
Expand Down Expand Up @@ -145,45 +316,65 @@ def _normalized_aws_roles(self) -> List[AwsAssumeRoleConfig]:

def get_session(self) -> Session:
    """Build a boto3 session, assuming the configured role chain if needed.

    Base credentials are chosen in priority order: explicit access keys,
    then a named profile, then boto3's own credential autodetection. When
    target roles are configured, they are assumed in sequence unless the
    process is already running as the (single) target role.

    Returns:
        A ``Session`` carrying either the base credentials or the
        credentials obtained from the last assumed role.

    Raises:
        ValueError: if roles must be assumed but the base session has no
            credentials to start the chain from.
    """
    if self.aws_access_key_id and self.aws_secret_access_key:
        # Explicit credentials take precedence
        session = Session(
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            aws_session_token=self.aws_session_token,
            region_name=self.aws_region,
        )
    elif self.aws_profile:
        # Named profile is second priority
        session = Session(
            region_name=self.aws_region, profile_name=self.aws_profile
        )
    else:
        # Use boto3's credential autodetection
        session = Session(region_name=self.aws_region)

    target_roles = self._normalized_aws_roles()
    if not target_roles:
        return session

    current_role_arn, credential_source = get_current_identity()

    # Only assume role if:
    # 1. We're not in a known AWS environment with a role, or
    # 2. We need to assume a different role than our current one
    should_assume_role = current_role_arn is None or any(
        role.RoleArn != current_role_arn for role in target_roles
    )

    if not should_assume_role:
        logger.debug(f"Using existing role from {credential_source}")
        return session

    # Environment detection may probe the metadata service over the network,
    # so only repeat it when the debug message will actually be emitted.
    if logger.isEnabledFor(logging.DEBUG):
        env = detect_aws_environment()
        logger.debug(f"Assuming role(s) from {env.value} environment")

    # Use existing session credentials to start the chain of role assumption.
    current_credentials = session.get_credentials()
    if current_credentials is None:
        raise ValueError("No credentials available for role assumption")

    credentials = {
        "AccessKeyId": current_credentials.access_key,
        "SecretAccessKey": current_credentials.secret_key,
        "SessionToken": current_credentials.token,
    }

    for role in target_roles:
        if self._should_refresh_credentials():
            credentials = assume_role(
                role=role,
                aws_region=self.aws_region,
                credentials=credentials,
            )
            # Remember the expiry so later calls know when to refresh.
            if isinstance(credentials["Expiration"], datetime):
                self._credentials_expiration = credentials["Expiration"]

    return Session(
        aws_access_key_id=credentials["AccessKeyId"],
        aws_secret_access_key=credentials["SecretAccessKey"],
        aws_session_token=credentials["SessionToken"],
        region_name=self.aws_region,
    )

Expand Down
Loading
Loading