diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py b/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py
index 2e16f64b4004f..77165104fdab8 100644
--- a/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py
+++ b/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py
@@ -10,33 +10,24 @@
 from dagster._core.libraries import DagsterLibraryRegistry

-from .databricks import DatabricksClient, DatabricksError, DatabricksJobRunner
+from .databricks import (
+    DatabricksClient as DatabricksClient,
+    DatabricksError as DatabricksError,
+    DatabricksJobRunner as DatabricksJobRunner,
+)
 from .databricks_pyspark_step_launcher import (
-    DatabricksConfig,
-    DatabricksPySparkStepLauncher,
-    databricks_pyspark_step_launcher,
+    DatabricksConfig as DatabricksConfig,
+    DatabricksPySparkStepLauncher as DatabricksPySparkStepLauncher,
+    databricks_pyspark_step_launcher as databricks_pyspark_step_launcher,
 )
 from .ops import (
-    create_databricks_run_now_op,
-    create_databricks_submit_run_op,
+    create_databricks_run_now_op as create_databricks_run_now_op,
+    create_databricks_submit_run_op as create_databricks_submit_run_op,
 )
 from .resources import (
     DatabricksClientResource as DatabricksClientResource,
-    databricks_client,
+    databricks_client as databricks_client,
 )
 from .version import __version__

 DagsterLibraryRegistry.register("dagster-databricks", __version__)
-
-__all__ = [
-    "create_databricks_run_now_op",
-    "create_databricks_submit_run_op",
-    "databricks_client",
-    "DatabricksClient",
-    "DatabricksConfig",
-    "DatabricksError",
-    "DatabricksJobRunner",
-    "DatabricksPySparkStepLauncher",
-    "databricks_pyspark_step_launcher",
-    "DatabricksClientResource",
-]
diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/configs.py b/python_modules/libraries/dagster-databricks/dagster_databricks/configs.py
index 30d92b9dd4892..30a8151c8a4cf 100644
--- a/python_modules/libraries/dagster-databricks/dagster_databricks/configs.py
+++ b/python_modules/libraries/dagster-databricks/dagster_databricks/configs.py
@@ -9,10 +9,12 @@
 - https://docs.databricks.com/dev-tools/api/latest/clusters.html
 - https://docs.databricks.com/dev-tools/api/latest/libraries.html
 """
+from typing import Mapping, Union
+
 from dagster import Array, Bool, Enum, EnumValue, Field, Int, Permissive, Selector, Shape, String


-def _define_autoscale():
+def _define_autoscale() -> Field:
     return Field(
         Shape(
             fields={
@@ -36,7 +38,7 @@ def _define_autoscale():
     )


-def _define_size():
+def _define_size() -> Selector:
     num_workers = Field(
         Int,
         description=(
@@ -49,7 +51,7 @@ def _define_size():
     return Selector({"autoscale": _define_autoscale(), "num_workers": num_workers})


-def _define_custom_tags():
+def _define_custom_tags() -> Field:
     key = Field(
         String,
         description=(
@@ -81,12 +83,12 @@ def _define_custom_tags():
     )


-def _define_dbfs_storage_info():
+def _define_dbfs_storage_info() -> Field:
     destination = Field(String, description="DBFS destination, e.g. dbfs:/my/path")
     return Field(Shape(fields={"destination": destination}), description="DBFS storage information")


-def _define_s3_storage_info():
+def _define_s3_storage_info() -> Field:
     destination = Field(
         String,
         description=(
@@ -160,7 +162,7 @@ def _define_s3_storage_info():
     )


-def _define_aws_attributes_conf():
+def _define_aws_attributes_conf() -> Field:
     return Field(
         Permissive(
             fields={
@@ -258,7 +260,7 @@ def _define_aws_attributes_conf():
     )


-def _define_cluster_log_conf():
+def _define_cluster_log_conf() -> Field:
     return Field(
         Selector({"dbfs": _define_dbfs_storage_info(), "s3": _define_s3_storage_info()}),
         description=(
@@ -276,7 +278,7 @@ def _define_init_script():
     return Selector({"dbfs": _define_dbfs_storage_info(), "s3": _define_s3_storage_info()})


-def _define_node_types():
+def _define_node_types() -> Field:
     node_type_id = Field(
         String,
         description=(
@@ -304,7 +306,7 @@ def _define_node_types():
     )


-def _define_nodes():
+def _define_nodes() -> Field:
     instance_pool_id = Field(
         String,
         description=(
@@ -324,7 +326,7 @@ def _define_nodes():
     )


-def _define_new_cluster():
+def _define_new_cluster() -> Field:
     spark_version = Field(
         String,
         description=(
@@ -416,7 +418,7 @@ def _define_new_cluster():
     )


-def _define_cluster():
+def _define_cluster() -> Selector:
     existing_cluster_id = Field(
         String,
         description=(
@@ -431,7 +433,7 @@ def _define_cluster():
     return Selector({"new": _define_new_cluster(), "existing": existing_cluster_id})


-def _define_pypi_library():
+def _define_pypi_library() -> Field:
     package = Field(
         String,
         description=(
@@ -457,7 +459,7 @@ def _define_pypi_library():
     )


-def _define_maven_library():
+def _define_maven_library() -> Field:
     coordinates = Field(
         String,
         description=(
@@ -490,7 +492,7 @@ def _define_maven_library():
     )


-def _define_cran_library():
+def _define_cran_library() -> Field:
     package = Field(
         String,
         description="The name of the CRAN package to install. This field is required.",
@@ -510,7 +512,7 @@ def _define_cran_library():
     )


-def _define_libraries():
+def _define_libraries() -> Field:
     jar = Field(
         String,
         description=(
@@ -564,7 +566,7 @@ def _define_libraries():
     )


-def _define_submit_run_fields():
+def _define_submit_run_fields() -> Mapping[str, Union[Selector, Field]]:
     run_name = Field(
         String,
         description="An optional name for the run. The default value is Untitled",
@@ -610,7 +612,7 @@ def _define_submit_run_fields():
     }


-def _define_notebook_task():
+def _define_notebook_task() -> Field:
     notebook_path = Field(
         String,
         description=(
@@ -632,7 +634,7 @@ def _define_notebook_task():
     return Field(Shape(fields={"notebook_path": notebook_path, "base_parameters": base_parameters}))


-def _define_spark_jar_task():
+def _define_spark_jar_task() -> Field:
     main_class_name = Field(
         String,
         description=(
@@ -652,7 +654,7 @@ def _define_spark_jar_task():
     return Field(Shape(fields={"main_class_name": main_class_name, "parameters": parameters}))


-def _define_spark_python_task():
+def _define_spark_python_task() -> Field:
     python_file = Field(
         String,
         description=(
@@ -670,7 +672,7 @@ def _define_spark_python_task():
     return Field(Shape(fields={"python_file": python_file, "parameters": parameters}))


-def _define_spark_submit_task():
+def _define_spark_submit_task() -> Field:
     parameters = Field(
         [String],
         description="Command-line parameters passed to spark submit.",
@@ -692,7 +694,7 @@ def _define_spark_submit_task():
     )


-def _define_task():
+def _define_task() -> Field:
     return Field(
         Selector(
             {
@@ -707,20 +709,20 @@ def _define_task():
     )


-def define_databricks_submit_custom_run_config():
+def define_databricks_submit_custom_run_config() -> Field:
     fields = _define_submit_run_fields()
     fields["task"] = _define_task()
     return Field(Shape(fields=fields), description="Databricks job run configuration")


-def define_databricks_submit_run_config():
+def define_databricks_submit_run_config() -> Field:
     return Field(
         Shape(fields=_define_submit_run_fields()),
         description="Databricks job run configuration",
     )


-def _define_secret_scope():
+def _define_secret_scope() -> Field:
     return Field(
         String,
         description="The Databricks secret scope containing the storage secrets.",
@@ -728,7 +730,7 @@ def _define_secret_scope():
     )


-def _define_s3_storage_credentials():
+def _define_s3_storage_credentials() -> Field:
     access_key_key = Field(
         String,
         description="The key of a Databricks secret containing the S3 access key ID.",
@@ -751,7 +753,7 @@ def _define_s3_storage_credentials():
     )


-def _define_adls2_storage_credentials():
+def _define_adls2_storage_credentials() -> Field:
     storage_account_name = Field(
         String,
         description="The name of the storage account used to access data.",
@@ -774,7 +776,7 @@ def _define_adls2_storage_credentials():
     )


-def define_databricks_storage_config():
+def define_databricks_storage_config() -> Field:
     return Field(
         Selector(
             {
@@ -791,7 +793,7 @@ def define_databricks_storage_config():
     )


-def define_databricks_env_variables():
+def define_databricks_env_variables() -> Field:
     return Field(
         Permissive(),
         description=(
@@ -801,7 +803,7 @@ def define_databricks_env_variables():
     )


-def define_databricks_secrets_config():
+def define_databricks_secrets_config() -> Field:
     name = Field(
         String,
         description="The environment variable name, e.g. `DATABRICKS_TOKEN`.",
@@ -822,14 +824,14 @@ def define_databricks_secrets_config():
     )


-def _define_accessor():
+def _define_accessor() -> Selector:
     return Selector(
         {"group_name": str, "user_name": str},
         description="Group or User that shall access the target.",
     )


-def _define_databricks_job_permission():
+def _define_databricks_job_permission() -> Field:
     job_permission_levels = [
         "NO_PERMISSIONS",
         "CAN_VIEW",
@@ -850,7 +852,7 @@ def _define_databricks_job_permission():
     )


-def _define_databricks_cluster_permission():
+def _define_databricks_cluster_permission() -> Field:
     cluster_permission_levels = ["NO_PERMISSIONS", "CAN_ATTACH_TO", "CAN_RESTART", "CAN_MANAGE"]
     return Field(
         {
@@ -865,7 +867,7 @@ def _define_databricks_cluster_permission():
     )


-def define_databricks_permissions():
+def define_databricks_permissions() -> Field:
     return Field(
         {
             "job_permissions": _define_databricks_job_permission(),
diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/databricks.py b/python_modules/libraries/dagster-databricks/dagster_databricks/databricks.py
index 3f0c58cafbbab..055688b45d699 100644
--- a/python_modules/libraries/dagster-databricks/dagster_databricks/databricks.py
+++ b/python_modules/libraries/dagster-databricks/dagster_databricks/databricks.py
@@ -1,7 +1,7 @@
 import base64
 import logging
 import time
-from typing import Any, Mapping, Optional
+from typing import Any, Mapping, Optional, Tuple

 import dagster
 import dagster._check as check
@@ -209,8 +209,10 @@ class DatabricksJobRunner:
     """Submits jobs created using Dagster config to Databricks, and monitors their progress.

     Attributes:
-        host (str): Databricks host, e.g. https://uksouth.azuredatabricks.net
-        token (str): Databricks token
+        host (str): Databricks host, e.g. https://uksouth.azuredatabricks.net.
+        token (str): Databricks authentication token.
+        poll_interval_sec (float): How often to poll Databricks for run status.
+        max_wait_time_sec (int): How long to wait for a run to complete before failing.
     """

     def __init__(
@@ -314,7 +316,9 @@ def submit_run(self, run_config: Mapping[str, Any], task: Mapping[str, Any]) ->
         }
         return JobsService(self.client.api_client).submit_run(**config)["run_id"]

-    def retrieve_logs_for_run_id(self, log: logging.Logger, databricks_run_id: int):
+    def retrieve_logs_for_run_id(
+        self, log: logging.Logger, databricks_run_id: int
+    ) -> Optional[Tuple[Optional[str], Optional[str]]]:
         """Retrieve the stdout and stderr logs for a run."""
         api_client = self.client.api_client

@@ -341,9 +345,9 @@ def retrieve_logs_for_run_id(self, log: logging.Logger, databricks_run_id: int):
     def wait_for_dbfs_logs(
         self,
         log: logging.Logger,
-        prefix,
-        cluster_id,
-        filename,
+        prefix: str,
+        cluster_id: str,
+        filename: str,
         waiter_delay: int = 10,
         waiter_max_attempts: int = 10,
     ) -> Optional[str]:
diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/databricks_pyspark_step_launcher.py b/python_modules/libraries/dagster-databricks/dagster_databricks/databricks_pyspark_step_launcher.py
index 8720942d945a8..6920a9d3f48d3 100644
--- a/python_modules/libraries/dagster-databricks/dagster_databricks/databricks_pyspark_step_launcher.py
+++ b/python_modules/libraries/dagster-databricks/dagster_databricks/databricks_pyspark_step_launcher.py
@@ -6,6 +6,7 @@
 import tempfile
 import time
 import zlib
+from typing import Iterator, Sequence, cast

 from dagster import (
     Bool,
@@ -18,6 +19,10 @@
 from dagster._core.definitions.resource_definition import dagster_maintained_resource
 from dagster._core.definitions.step_launcher import StepLauncher
 from dagster._core.errors import raise_execution_interrupts
+from dagster._core.events import DagsterEvent
+from dagster._core.events.log import EventLogEntry
+from dagster._core.execution.context.init import InitResourceContext
+from dagster._core.execution.context.system import StepExecutionContext
 from dagster._core.execution.plan.external_step import (
     PICKLED_EVENTS_FILE_NAME,
     PICKLED_STEP_RUN_REF_FILE_NAME,
@@ -161,7 +166,9 @@
         ),
     }
 )
-def databricks_pyspark_step_launcher(context):
+def databricks_pyspark_step_launcher(
+    context: InitResourceContext,
+) -> "DatabricksPySparkStepLauncher":
     """Resource for running ops as a Databricks Job.

     When this resource is used, the op will be executed in Databricks using the 'Run Submit'
@@ -233,7 +240,7 @@ def __init__(
             add_dagster_env_variables, "add_dagster_env_variables"
         )

-    def launch_step(self, step_context):
+    def launch_step(self, step_context: StepExecutionContext) -> Iterator[DagsterEvent]:
         step_run_ref = step_context_to_step_run_ref(
             step_context, self.local_dagster_job_package_path
         )
@@ -290,7 +297,9 @@ def log_compute_logs(self, log, run_id, step_key):
                 f" {step_key}. Check the databricks console for more info."
             )

-    def step_events_iterator(self, step_context, step_key: str, databricks_run_id: int):
+    def step_events_iterator(
+        self, step_context: StepExecutionContext, step_key: str, databricks_run_id: int
+    ) -> Iterator[DagsterEvent]:
        """The launched Databricks job writes all event records to a specific dbfs file. This
        iterator regularly reads the contents of the file, adds any events that have not yet been
        seen to the instance, and yields any DagsterEvents.
@@ -330,19 +339,24 @@ def step_events_iterator(self, step_context, step_key: str, databricks_run_id: i
                 # write each event from the DataBricks instance to the local instance
                 step_context.instance.handle_new_event(event)
                 if event.is_dagster_event:
-                    yield event.dagster_event
+                    yield event.get_dagster_event()
             processed_events = len(all_events)

         step_context.log.info(f"Databricks run {databricks_run_id} completed.")

-    def get_step_events(self, run_id: str, step_key: str, retry_number: int):
+    def get_step_events(
+        self, run_id: str, step_key: str, retry_number: int
+    ) -> Sequence[EventLogEntry]:
         path = self._dbfs_path(run_id, step_key, f"{retry_number}_{PICKLED_EVENTS_FILE_NAME}")

-        def _get_step_records():
+        def _get_step_records() -> Sequence[EventLogEntry]:
             serialized_records = self.databricks_runner.client.read_file(path)
             if not serialized_records:
                 return []
-            return deserialize_value(pickle.loads(gzip.decompress(serialized_records)))
+            return cast(
+                Sequence[EventLogEntry],
+                deserialize_value(pickle.loads(gzip.decompress(serialized_records))),
+            )

         try:
             # reading from dbfs while it writes can be flaky
diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/ops.py b/python_modules/libraries/dagster-databricks/dagster_databricks/ops.py
index a871bb4ebb1e9..40b22b5f821c3 100644
--- a/python_modules/libraries/dagster-databricks/dagster_databricks/ops.py
+++ b/python_modules/libraries/dagster-databricks/dagster_databricks/ops.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, cast

 from dagster import (
     In,
@@ -112,10 +112,13 @@ def _databricks_run_now_op(
         databricks: DatabricksClient = getattr(context.resources, databricks_resource_key)
         jobs_service = JobsService(databricks.api_client)

-        run_id: int = jobs_service.run_now(
-            job_id=databricks_job_id,
-            **(databricks_job_configuration or {}),
-        )["run_id"]
+        run_id = cast(
+            int,
+            jobs_service.run_now(
+                job_id=databricks_job_id,
+                **(databricks_job_configuration or {}),
+            )["run_id"],
+        )

         get_run_response: dict = jobs_service.get_run(run_id=run_id)
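# Note on the __init__.py hunk above: "import X as X" is the explicit re-export idiom.
# Under strict type checking (e.g. mypy's no_implicit_reexport, or pyright's default
# handling of py.typed packages), only names imported as themselves or listed in __all__
# are treated as public, which is presumably why the hand-maintained __all__ list could
# be dropped. Minimal sketch with a hypothetical package layout, not the
# dagster-databricks API:
#
#   mypkg/_client.py   defines   class Client: ...
#   mypkg/__init__.py  contains the single line below
#
# mypkg/__init__.py
from mypkg._client import Client as Client  # explicit re-export: Client is public API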
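# Sketch of the typing.cast() pattern adopted in ops.py and the step launcher above.
# cast() performs no runtime conversion; it only tells the type checker what type to
# assume for a value coming from an untyped API. The response dict here is a stand-in
# for the Databricks client's return value, not the real API.
from typing import Any, Mapping, cast


def extract_run_id(response: Mapping[str, Any]) -> int:
    # response["run_id"] is typed as Any; cast() narrows it to int for the checker.
    return cast(int, response["run_id"])


print(extract_run_id({"run_id": 42}))  # 42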
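# Sketch of the generator annotations added to launch_step/step_events_iterator: a
# function that yields values of type T is annotated as returning Iterator[T]. The Event
# class below is an illustrative stand-in, not Dagster's DagsterEvent.
from dataclasses import dataclass
from typing import Iterator, List


@dataclass
class Event:
    message: str


def poll_events(raw: List[str]) -> Iterator[Event]:
    for item in raw:
        # the yield makes this a generator; Iterator[Event] describes what it produces
        yield Event(message=item)


for evt in poll_events(["started", "completed"]):
    print(evt.message)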