Skip to content

Commit

Permalink
Update dagster-databricks to use new official databricks Python SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Aug 21, 2023
1 parent 45cdf4c commit 22acced
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,24 @@

from dagster._core.libraries import DagsterLibraryRegistry

from .databricks import DatabricksClient, DatabricksError, DatabricksJobRunner
from .databricks import (
DatabricksClient as DatabricksClient,
DatabricksError as DatabricksError,
DatabricksJobRunner as DatabricksJobRunner,
)
from .databricks_pyspark_step_launcher import (
DatabricksConfig,
DatabricksPySparkStepLauncher,
databricks_pyspark_step_launcher,
DatabricksConfig as DatabricksConfig,
DatabricksPySparkStepLauncher as DatabricksPySparkStepLauncher,
databricks_pyspark_step_launcher as databricks_pyspark_step_launcher,
)
from .ops import (
create_databricks_run_now_op,
create_databricks_submit_run_op,
create_databricks_run_now_op as create_databricks_run_now_op,
create_databricks_submit_run_op as create_databricks_submit_run_op,
)
from .resources import (
DatabricksClientResource as DatabricksClientResource,
databricks_client,
databricks_client as databricks_client,
)
from .version import __version__

DagsterLibraryRegistry.register("dagster-databricks", __version__)

__all__ = [
"create_databricks_run_now_op",
"create_databricks_submit_run_op",
"databricks_client",
"DatabricksClient",
"DatabricksConfig",
"DatabricksError",
"DatabricksJobRunner",
"DatabricksPySparkStepLauncher",
"databricks_pyspark_step_launcher",
"DatabricksClientResource",
]
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
- https://docs.databricks.com/dev-tools/api/latest/clusters.html
- https://docs.databricks.com/dev-tools/api/latest/libraries.html
"""
from typing import Mapping, Union

from dagster import Array, Bool, Enum, EnumValue, Field, Int, Permissive, Selector, Shape, String


def _define_autoscale():
def _define_autoscale() -> Field:
return Field(
Shape(
fields={
Expand All @@ -36,7 +38,7 @@ def _define_autoscale():
)


def _define_size():
def _define_size() -> Selector:
num_workers = Field(
Int,
description=(
Expand All @@ -49,7 +51,7 @@ def _define_size():
return Selector({"autoscale": _define_autoscale(), "num_workers": num_workers})


def _define_custom_tags():
def _define_custom_tags() -> Field:
key = Field(
String,
description=(
Expand Down Expand Up @@ -81,12 +83,12 @@ def _define_custom_tags():
)


def _define_dbfs_storage_info():
def _define_dbfs_storage_info() -> Field:
destination = Field(String, description="DBFS destination, e.g. dbfs:/my/path")
return Field(Shape(fields={"destination": destination}), description="DBFS storage information")


def _define_s3_storage_info():
def _define_s3_storage_info() -> Field:
destination = Field(
String,
description=(
Expand Down Expand Up @@ -160,7 +162,7 @@ def _define_s3_storage_info():
)


def _define_aws_attributes_conf():
def _define_aws_attributes_conf() -> Field:
return Field(
Permissive(
fields={
Expand Down Expand Up @@ -258,7 +260,7 @@ def _define_aws_attributes_conf():
)


def _define_cluster_log_conf():
def _define_cluster_log_conf() -> Field:
return Field(
Selector({"dbfs": _define_dbfs_storage_info(), "s3": _define_s3_storage_info()}),
description=(
Expand All @@ -276,7 +278,7 @@ def _define_init_script():
return Selector({"dbfs": _define_dbfs_storage_info(), "s3": _define_s3_storage_info()})


def _define_node_types():
def _define_node_types() -> Field:
node_type_id = Field(
String,
description=(
Expand Down Expand Up @@ -304,7 +306,7 @@ def _define_node_types():
)


def _define_nodes():
def _define_nodes() -> Field:
instance_pool_id = Field(
String,
description=(
Expand All @@ -324,7 +326,7 @@ def _define_nodes():
)


def _define_new_cluster():
def _define_new_cluster() -> Field:
spark_version = Field(
String,
description=(
Expand Down Expand Up @@ -416,7 +418,7 @@ def _define_new_cluster():
)


def _define_cluster():
def _define_cluster() -> Selector:
existing_cluster_id = Field(
String,
description=(
Expand All @@ -431,7 +433,7 @@ def _define_cluster():
return Selector({"new": _define_new_cluster(), "existing": existing_cluster_id})


def _define_pypi_library():
def _define_pypi_library() -> Field:
package = Field(
String,
description=(
Expand All @@ -457,7 +459,7 @@ def _define_pypi_library():
)


def _define_maven_library():
def _define_maven_library() -> Field:
coordinates = Field(
String,
description=(
Expand Down Expand Up @@ -490,7 +492,7 @@ def _define_maven_library():
)


def _define_cran_library():
def _define_cran_library() -> Field:
package = Field(
String,
description="The name of the CRAN package to install. This field is required.",
Expand All @@ -510,7 +512,7 @@ def _define_cran_library():
)


def _define_libraries():
def _define_libraries() -> Field:
jar = Field(
String,
description=(
Expand Down Expand Up @@ -564,7 +566,7 @@ def _define_libraries():
)


def _define_submit_run_fields():
def _define_submit_run_fields() -> Mapping[str, Union[Selector, Field]]:
run_name = Field(
String,
description="An optional name for the run. The default value is Untitled",
Expand Down Expand Up @@ -610,7 +612,7 @@ def _define_submit_run_fields():
}


def _define_notebook_task():
def _define_notebook_task() -> Field:
notebook_path = Field(
String,
description=(
Expand All @@ -632,7 +634,7 @@ def _define_notebook_task():
return Field(Shape(fields={"notebook_path": notebook_path, "base_parameters": base_parameters}))


def _define_spark_jar_task():
def _define_spark_jar_task() -> Field:
main_class_name = Field(
String,
description=(
Expand All @@ -652,7 +654,7 @@ def _define_spark_jar_task():
return Field(Shape(fields={"main_class_name": main_class_name, "parameters": parameters}))


def _define_spark_python_task():
def _define_spark_python_task() -> Field:
python_file = Field(
String,
description=(
Expand All @@ -670,7 +672,7 @@ def _define_spark_python_task():
return Field(Shape(fields={"python_file": python_file, "parameters": parameters}))


def _define_spark_submit_task():
def _define_spark_submit_task() -> Field:
parameters = Field(
[String],
description="Command-line parameters passed to spark submit.",
Expand All @@ -692,7 +694,7 @@ def _define_spark_submit_task():
)


def _define_task():
def _define_task() -> Field:
return Field(
Selector(
{
Expand All @@ -707,28 +709,28 @@ def _define_task():
)


def define_databricks_submit_custom_run_config():
def define_databricks_submit_custom_run_config() -> Field:
fields = _define_submit_run_fields()
fields["task"] = _define_task()
return Field(Shape(fields=fields), description="Databricks job run configuration")


def define_databricks_submit_run_config():
def define_databricks_submit_run_config() -> Field:
return Field(
Shape(fields=_define_submit_run_fields()),
description="Databricks job run configuration",
)


def _define_secret_scope():
def _define_secret_scope() -> Field:
return Field(
String,
description="The Databricks secret scope containing the storage secrets.",
is_required=True,
)


def _define_s3_storage_credentials():
def _define_s3_storage_credentials() -> Field:
access_key_key = Field(
String,
description="The key of a Databricks secret containing the S3 access key ID.",
Expand All @@ -751,7 +753,7 @@ def _define_s3_storage_credentials():
)


def _define_adls2_storage_credentials():
def _define_adls2_storage_credentials() -> Field:
storage_account_name = Field(
String,
description="The name of the storage account used to access data.",
Expand All @@ -774,7 +776,7 @@ def _define_adls2_storage_credentials():
)


def define_databricks_storage_config():
def define_databricks_storage_config() -> Field:
return Field(
Selector(
{
Expand All @@ -791,7 +793,7 @@ def define_databricks_storage_config():
)


def define_databricks_env_variables():
def define_databricks_env_variables() -> Field:
return Field(
Permissive(),
description=(
Expand All @@ -801,7 +803,7 @@ def define_databricks_env_variables():
)


def define_databricks_secrets_config():
def define_databricks_secrets_config() -> Field:
name = Field(
String,
description="The environment variable name, e.g. `DATABRICKS_TOKEN`.",
Expand All @@ -822,14 +824,14 @@ def define_databricks_secrets_config():
)


def _define_accessor():
def _define_accessor() -> Selector:
return Selector(
{"group_name": str, "user_name": str},
description="Group or User that shall access the target.",
)


def _define_databricks_job_permission():
def _define_databricks_job_permission() -> Field:
job_permission_levels = [
"NO_PERMISSIONS",
"CAN_VIEW",
Expand All @@ -850,7 +852,7 @@ def _define_databricks_job_permission():
)


def _define_databricks_cluster_permission():
def _define_databricks_cluster_permission() -> Field:
cluster_permission_levels = ["NO_PERMISSIONS", "CAN_ATTACH_TO", "CAN_RESTART", "CAN_MANAGE"]
return Field(
{
Expand All @@ -865,7 +867,7 @@ def _define_databricks_cluster_permission():
)


def define_databricks_permissions():
def define_databricks_permissions() -> Field:
return Field(
{
"job_permissions": _define_databricks_job_permission(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import logging
import time
from typing import Any, Mapping, Optional
from typing import Any, Mapping, Optional, Tuple

import dagster
import dagster._check as check
Expand Down Expand Up @@ -209,8 +209,10 @@ class DatabricksJobRunner:
"""Submits jobs created using Dagster config to Databricks, and monitors their progress.
Attributes:
host (str): Databricks host, e.g. https://uksouth.azuredatabricks.net
token (str): Databricks token
host (str): Databricks host, e.g. https://uksouth.azuredatabricks.net.
token (str): Databricks authentication token.
poll_interval_sec (float): How often to poll Databricks for run status.
max_wait_time_sec (int): How long to wait for a run to complete before failing.
"""

def __init__(
Expand Down Expand Up @@ -314,7 +316,9 @@ def submit_run(self, run_config: Mapping[str, Any], task: Mapping[str, Any]) ->
}
return JobsService(self.client.api_client).submit_run(**config)["run_id"]

def retrieve_logs_for_run_id(self, log: logging.Logger, databricks_run_id: int):
def retrieve_logs_for_run_id(
self, log: logging.Logger, databricks_run_id: int
) -> Optional[Tuple[Optional[str], Optional[str]]]:
"""Retrieve the stdout and stderr logs for a run."""
api_client = self.client.api_client

Expand All @@ -341,9 +345,9 @@ def retrieve_logs_for_run_id(self, log: logging.Logger, databricks_run_id: int):
def wait_for_dbfs_logs(
self,
log: logging.Logger,
prefix,
cluster_id,
filename,
prefix: str,
cluster_id: str,
filename: str,
waiter_delay: int = 10,
waiter_max_attempts: int = 10,
) -> Optional[str]:
Expand Down
Loading

0 comments on commit 22acced

Please sign in to comment.