diff --git a/pyright/alt-1/requirements-pinned.txt b/pyright/alt-1/requirements-pinned.txt index 9adabce9b43d1..194895a13f60f 100644 --- a/pyright/alt-1/requirements-pinned.txt +++ b/pyright/alt-1/requirements-pinned.txt @@ -171,6 +171,7 @@ multidict==6.1.0 multimethod==1.10 mypy==1.11.2 mypy-boto3-ecs==1.35.21 +mypy-boto3-emr==1.35.18 mypy-boto3-emr-serverless==1.35.25 mypy-boto3-glue==1.35.25 mypy-boto3-s3==1.35.32 diff --git a/pyright/master/requirements-pinned.txt b/pyright/master/requirements-pinned.txt index 549546fa6b1a3..b7c91d5997efb 100644 --- a/pyright/master/requirements-pinned.txt +++ b/pyright/master/requirements-pinned.txt @@ -360,6 +360,7 @@ msgpack==1.1.0 multidict==6.1.0 multimethod==1.10 mypy-boto3-ecs==1.35.21 +mypy-boto3-emr==1.35.18 mypy-boto3-emr-serverless==1.35.25 mypy-boto3-glue==1.35.25 mypy-boto3-s3==1.35.32 diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes/__init__.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes/__init__.py index 3902a5a642f31..ef95c70ae21d3 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/pipes/__init__.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes/__init__.py @@ -1,5 +1,6 @@ from dagster_aws.pipes.clients import ( PipesECSClient, + PipesEMRClient, PipesEMRServerlessClient, PipesGlueClient, PipesLambdaClient, @@ -19,6 +20,7 @@ "PipesGlueClient", "PipesLambdaClient", "PipesECSClient", + "PipesEMRClient", "PipesS3ContextInjector", "PipesLambdaEventContextInjector", "PipesS3MessageReader", diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/__init__.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/__init__.py index a6711bbf1ce82..e3895ac00d101 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/__init__.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/__init__.py @@ -1,6 +1,13 @@ from dagster_aws.pipes.clients.ecs import PipesECSClient +from dagster_aws.pipes.clients.emr import PipesEMRClient from dagster_aws.pipes.clients.emr_serverless import PipesEMRServerlessClient from dagster_aws.pipes.clients.glue import PipesGlueClient from dagster_aws.pipes.clients.lambda_ import PipesLambdaClient -__all__ = ["PipesGlueClient", "PipesLambdaClient", "PipesECSClient", "PipesEMRServerlessClient"] +__all__ = [ + "PipesGlueClient", + "PipesLambdaClient", + "PipesECSClient", + "PipesEMRServerlessClient", + "PipesEMRClient", +] diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/emr.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/emr.py new file mode 100644 index 0000000000000..2719f21625b93 --- /dev/null +++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/emr.py @@ -0,0 +1,339 @@ +import os +import sys +import time +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast + +import boto3 +import dagster._check as check +from dagster import PipesClient +from dagster._annotations import public +from dagster._core.definitions.resource_annotation import TreatAsResourceParam +from dagster._core.errors import DagsterExecutionInterruptedError +from dagster._core.execution.context.compute import OpExecutionContext +from dagster._core.pipes.client import ( + PipesClientCompletedInvocation, + PipesContextInjector, + PipesMessageReader, +) +from dagster._core.pipes.utils import PipesEnvContextInjector, PipesSession, open_pipes_session + +from dagster_aws.emr.emr import EMR_CLUSTER_TERMINATED_STATES +from dagster_aws.pipes.message_readers import ( + PipesS3LogReader, + PipesS3MessageReader, + gzip_log_decode_fn, +) + +if TYPE_CHECKING: + from mypy_boto3_emr import EMRClient + from mypy_boto3_emr.literals import ClusterStateType + from mypy_boto3_emr.type_defs import ( + ConfigurationUnionTypeDef, + DescribeClusterOutputTypeDef, + RunJobFlowInputRequestTypeDef, + RunJobFlowOutputTypeDef, + ) + + +def add_configuration( + configurations: List["ConfigurationUnionTypeDef"], + configuration: "ConfigurationUnionTypeDef", +): + """Add a configuration to a list of EMR configurations, merging configurations with the same classification. + + This is necessary because EMR doesn't accept multiple configurations with the same classification. + """ + for existing_configuration in configurations: + if existing_configuration.get("Classification") is not None and existing_configuration.get( + "Classification" + ) == configuration.get("Classification"): + properties = {**existing_configuration.get("Properties", {})} + properties.update(properties) + + inner_configurations = cast( + List["ConfigurationUnionTypeDef"], existing_configuration.get("Configurations", []) + ) + + for inner_configuration in cast( + List["ConfigurationUnionTypeDef"], configuration.get("Configurations", []) + ): + add_configuration(inner_configurations, inner_configuration) + + existing_configuration["Properties"] = properties + existing_configuration["Configurations"] = inner_configurations + + break + else: + configurations.append(configuration) + + +class PipesEMRClient(PipesClient, TreatAsResourceParam): + """A pipes client for running jobs on AWS EMR. + + Args: + message_reader (Optional[PipesMessageReader]): A message reader to use to read messages + from the EMR jobs. + Recommended to use :py:class:`PipesS3MessageReader` with `expect_s3_message_writer` set to `True`. + client (Optional[boto3.client]): The boto3 EMR client used to interact with AWS EMR. + context_injector (Optional[PipesContextInjector]): A context injector to use to inject + context into AWS EMR job. Defaults to :py:class:`PipesEnvContextInjector`. + forward_termination (bool): Whether to cancel the EMR job if the Dagster process receives a termination signal. + wait_for_s3_logs_seconds (int): The number of seconds to wait for S3 logs to be written after execution completes. + """ + + def __init__( + self, + message_reader: PipesMessageReader, + client=None, + context_injector: Optional[PipesContextInjector] = None, + forward_termination: bool = True, + wait_for_s3_logs_seconds: int = 10, + ): + self._client = client or boto3.client("emr") + self._message_reader = message_reader + self._context_injector = context_injector or PipesEnvContextInjector() + self.forward_termination = check.bool_param(forward_termination, "forward_termination") + self.wait_for_s3_logs_seconds = wait_for_s3_logs_seconds + + @property + def client(self) -> "EMRClient": + return self._client + + @property + def context_injector(self) -> PipesContextInjector: + return self._context_injector + + @property + def message_reader(self) -> PipesMessageReader: + return self._message_reader + + @classmethod + def _is_dagster_maintained(cls) -> bool: + return True + + @public + def run( + self, + *, + context: OpExecutionContext, + run_job_flow_params: "RunJobFlowInputRequestTypeDef", + extras: Optional[Dict[str, Any]] = None, + ) -> PipesClientCompletedInvocation: + """Run a job on AWS EMR, enriched with the pipes protocol. + + Starts a new EMR cluster for each invocation. + + Args: + context (OpExecutionContext): The context of the currently executing Dagster op or asset. + run_job_flow_params (Optional[dict]): Parameters for the ``run_job_flow`` boto3 EMR client call. + See `Boto3 API Documentation `_ + extras (Optional[Dict[str, Any]]): Additional information to pass to the Pipes session in the external process. + + Returns: + PipesClientCompletedInvocation: Wrapper containing results reported by the external + process. + """ + with open_pipes_session( + context=context, + message_reader=self.message_reader, + context_injector=self.context_injector, + extras=extras, + ) as session: + run_job_flow_params = self._enrich_params(session, run_job_flow_params) + start_response = self._start(context, session, run_job_flow_params) + try: + self._add_log_readers(context, start_response) + wait_response = self._wait_for_completion(context, start_response) + self._read_remaining_logs(context, wait_response) + return PipesClientCompletedInvocation(session) + + except DagsterExecutionInterruptedError: + if self.forward_termination: + context.log.warning( + "[pipes] Dagster process interrupted! Will terminate external EMR job." + ) + self._terminate(context, start_response) + raise + + def _enrich_params( + self, session: PipesSession, params: "RunJobFlowInputRequestTypeDef" + ) -> "RunJobFlowInputRequestTypeDef": + # add Pipes env variables + pipes_env_vars = session.get_bootstrap_env_vars() + + configurations = cast(List["ConfigurationUnionTypeDef"], params.get("Configurations", [])) + + # add all possible env vars to spark-defaults, spark-env, yarn-env, hadoop-env + # since we can't be sure which one will be used by the job + add_configuration( + configurations, + { + "Classification": "spark-defaults", + "Properties": { + f"spark.yarn.appMasterEnv.{var}": value for var, value in pipes_env_vars.items() + }, + }, + ) + + for classification in ["spark-env", "yarn-env", "hadoop-env"]: + add_configuration( + configurations, + { + "Classification": classification, + "Configurations": [ + { + "Classification": "export", + "Properties": pipes_env_vars, + } + ], + }, + ) + + params["Configurations"] = configurations + + tags = list(params.get("Tags", [])) + + for key, value in session.default_remote_invocation_info.items(): + tags.append({"Key": key, "Value": value}) + + params["Tags"] = tags + + return params + + def _start( + self, + context: OpExecutionContext, + session: PipesSession, + params: "RunJobFlowInputRequestTypeDef", + ) -> "RunJobFlowOutputTypeDef": + response = self._client.run_job_flow(**params) + + session.report_launched({"extras": response}) + + cluster_id = response["JobFlowId"] + + context.log.info(f"[pipes] EMR steps started in cluster {cluster_id}") + return response + + def _wait_for_completion( + self, context: OpExecutionContext, response: "RunJobFlowOutputTypeDef" + ) -> "DescribeClusterOutputTypeDef": + cluster_id = response["JobFlowId"] + self._client.get_waiter("cluster_running").wait(ClusterId=cluster_id) + context.log.info(f"[pipes] EMR cluster {cluster_id} running") + # now wait for the job to complete + self._client.get_waiter("cluster_terminated").wait(ClusterId=cluster_id) + + cluster = self._client.describe_cluster(ClusterId=cluster_id) + + state: ClusterStateType = cluster["Cluster"]["Status"]["State"] + + context.log.info(f"[pipes] EMR cluster {cluster_id} completed with state: {state}") + + if state in EMR_CLUSTER_TERMINATED_STATES: + context.log.error(f"[pipes] EMR job {cluster_id} failed") + raise Exception(f"[pipes] EMR job {cluster_id} failed:\n{cluster}") + + return cluster + + def _add_log_readers(self, context: OpExecutionContext, response: "RunJobFlowOutputTypeDef"): + cluster = self.client.describe_cluster(ClusterId=response["JobFlowId"]) + + cluster_id = cluster["Cluster"]["Id"] + logs_uri = cluster.get("Cluster", {}).get("LogUri", {}) + + if isinstance(self.message_reader, PipesS3MessageReader) and logs_uri is None: + context.log.warning( + "[pipes] LogUri is not set in the EMR cluster configuration. Won't be able to read logs." + ) + elif isinstance(self.message_reader, PipesS3MessageReader) and isinstance(logs_uri, str): + bucket = logs_uri.split("/")[2] + prefix = "/".join(logs_uri.split("/")[3:]) + + steps = self.client.list_steps(ClusterId=cluster_id) + + # forward stdout and stderr from each step + + for step in steps["Steps"]: + step_id = step["Id"] + + for stdio in ["stdout", "stderr"]: + # at this stage we can't know if this key will be created + # for example, if a step doesn't have any stdout/stderr logs + # the PipesS3LogReader won't be able to start + # this may result in some unnecessary warnings + # there is not much we can do about it except perform step logs reading + # after the job is completed, which is not ideal too + key = os.path.join(prefix, f"{cluster_id}/steps/{step_id}/{stdio}.gz") + + self.message_reader.add_log_reader( + log_reader=PipesS3LogReader( + client=self.message_reader.client, + bucket=bucket, + key=key, + decode_fn=gzip_log_decode_fn, + target_stream=sys.stdout if stdio == "stdout" else sys.stderr, + debug_info=f"reader for {stdio} of EMR step {step_id}", + ), + ) + + def _read_remaining_logs( + self, context: OpExecutionContext, response: "DescribeClusterOutputTypeDef" + ): + cluster_id = response["Cluster"]["Id"] + logs_uri = response.get("Cluster", {}).get("LogUri", {}) + + if isinstance(self.message_reader, PipesS3MessageReader) and isinstance(logs_uri, str): + bucket = logs_uri.split("/")[2] + prefix = "/".join(logs_uri.split("/")[3:]) + + # discover container (application) logs (e.g. Python logs) and forward all of them + # ex. /containers/application_1727881613116_0001/container_1727881613116_0001_01_000001/stdout.gz + containers_prefix = os.path.join(prefix, f"{cluster_id}/containers/") + + context.log.debug( + f"[pipes] Waiting for {self.wait_for_s3_logs_seconds} seconds to allow EMR to dump all logs to S3. " + "Consider increasing this value if some logs are missing." + ) + + time.sleep(self.wait_for_s3_logs_seconds) # give EMR a chance to dump all logs to S3 + + context.log.debug( + f"[pipes] Looking for application logs in s3://{os.path.join(bucket, containers_prefix)}" + ) + + all_keys = [ + obj["Key"] + for obj in self.message_reader.client.list_objects_v2( + Bucket=bucket, Prefix=containers_prefix + )["Contents"] + ] + + # filter keys which include stdout.gz or stderr.gz + + container_log_keys = {} + for key in all_keys: + if "stdout.gz" in key: + container_log_keys[key] = "stdout" + elif "stderr.gz" in key: + container_log_keys[key] = "stderr" + + # forward application logs + + for key, stdio in container_log_keys.items(): + container_id = key.split("/")[-2] + self.message_reader.add_log_reader( + log_reader=PipesS3LogReader( + client=self.message_reader.client, + bucket=bucket, + key=key, + decode_fn=gzip_log_decode_fn, + target_stream=sys.stdout if stdio == "stdout" else sys.stderr, + debug_info=f"log reader for container {container_id} {stdio}", + ), + ) + + def _terminate(self, context: OpExecutionContext, start_response: "RunJobFlowOutputTypeDef"): + cluster_id = start_response["JobFlowId"] + context.log.info(f"[pipes] Terminating EMR job {cluster_id}") + self._client.terminate_job_flows(JobFlowIds=[cluster_id]) diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes/message_readers.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes/message_readers.py index fd5ce2117a988..6f00364908a62 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/pipes/message_readers.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes/message_readers.py @@ -1,4 +1,5 @@ import base64 +import gzip import os import random import string @@ -53,6 +54,10 @@ def default_log_decode_fn(contents: bytes) -> str: return contents.decode("utf-8") +def gzip_log_decode_fn(contents: bytes) -> str: + return gzip.decompress(contents).decode("utf-8") + + class PipesS3LogReader(PipesChunkedLogReader): def __init__( self, diff --git a/python_modules/libraries/dagster-aws/ruff.toml b/python_modules/libraries/dagster-aws/ruff.toml index 244f818c1fc43..64c0e85362012 100644 --- a/python_modules/libraries/dagster-aws/ruff.toml +++ b/python_modules/libraries/dagster-aws/ruff.toml @@ -8,7 +8,10 @@ extend-select = [ [lint.flake8-tidy-imports] banned-module-level-imports = [ + "mypy_boto3_s3", + "mypy_boto3_logs", "mypy_boto3_ecs", "mypy_boto3_glue", - "mypy_boto3_emr_serverless" + "mypy_boto3_emr_serverless", + "mypy_boto3_emr" ] diff --git a/python_modules/libraries/dagster-aws/setup.py b/python_modules/libraries/dagster-aws/setup.py index 85ccfd520abc7..9923d70197e31 100644 --- a/python_modules/libraries/dagster-aws/setup.py +++ b/python_modules/libraries/dagster-aws/setup.py @@ -37,6 +37,7 @@ def get_version() -> str: python_requires=">=3.8,<3.13", install_requires=[ "boto3", + "boto3-stubs-lite[ecs,glue,emr,emr-serverless]", f"dagster{pin}", "packaging", "requests", @@ -45,7 +46,7 @@ def get_version() -> str: "redshift": ["psycopg2-binary"], "pyspark": ["dagster-pyspark"], "stubs": [ - "boto3-stubs-lite[ecs,glue,emr-serverless,s3]", + "boto3-stubs-lite[ecs,glue,emr-serverless,s3,emr]", ], "test": [ "botocore!=1.32.1",