diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index c95b6eba4c..21338bd310 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -4,6 +4,7 @@ from dlt.common.typing import TSecretStrValue from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.configuration.exceptions import ConfigurationValueError DATABRICKS_APPLICATION_ID = "dltHub_dlt" @@ -15,6 +16,8 @@ class DatabricksCredentials(CredentialsConfiguration): server_hostname: str = None http_path: str = None access_token: Optional[TSecretStrValue] = None + client_id: Optional[TSecretStrValue] = None + client_secret: Optional[TSecretStrValue] = None http_headers: Optional[Dict[str, str]] = None session_configuration: Optional[Dict[str, Any]] = None """Dict of session parameters that will be passed to `databricks.sql.connect`""" @@ -27,9 +30,18 @@ class DatabricksCredentials(CredentialsConfiguration): "server_hostname", "http_path", "catalog", + "client_id", + "client_secret", "access_token", ] + def on_resolved(self) -> None: + if not ((self.client_id and self.client_secret) or self.access_token): + raise ConfigurationValueError( + "No valid authentication method detected. Provide either 'client_id' and" + " 'client_secret' for OAuth, or 'access_token' for token-based authentication." + ) + def to_connector_params(self) -> Dict[str, Any]: conn_params = dict( catalog=self.catalog, diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 8bff4e0d73..16e1e73d93 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -11,10 +11,12 @@ Tuple, Union, Dict, + cast, + Callable, ) - -from databricks import sql as databricks_lib +from databricks.sdk.core import Config, oauth_service_principal +from databricks import sql as databricks_lib # type: ignore[attr-defined] from databricks.sql.client import ( Connection as DatabricksSqlConnection, Cursor as DatabricksSqlCursor, @@ -73,8 +75,22 @@ def __init__( self._conn: DatabricksSqlConnection = None self.credentials = credentials + def _get_oauth_credentials(self) -> Optional[Callable[[], Dict[str, str]]]: + config = Config( + host=f"https://{self.credentials.server_hostname}", + client_id=self.credentials.client_id, + client_secret=self.credentials.client_secret, + ) + return cast(Callable[[], Dict[str, str]], oauth_service_principal(config)) + def open_connection(self) -> DatabricksSqlConnection: conn_params = self.credentials.to_connector_params() + + if self.credentials.client_id and self.credentials.client_secret: + conn_params["credentials_provider"] = self._get_oauth_credentials + else: + conn_params["access_token"] = self.credentials.access_token + self._conn = databricks_lib.connect( **conn_params, schema=self.dataset_name, use_inline_params="silent" ) diff --git a/docs/website/docs/dlt-ecosystem/destinations/databricks.md b/docs/website/docs/dlt-ecosystem/destinations/databricks.md index 513a3b792f..dd046ce28a 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/databricks.md +++ b/docs/website/docs/dlt-ecosystem/destinations/databricks.md @@ -90,6 +90,29 @@ If you already have your Databricks workspace set up, you can skip to the [Loade Click your email in the top right corner and go to "User Settings". Go to "Developer" -> "Access Tokens". Generate a new token and save it. You will use it in your `dlt` configuration. +## OAuth M2M (Machine-to-Machine) Authentication + +You can authenticate to Databricks using a service principal via OAuth M2M. This method allows for secure, programmatic access to Databricks resources without requiring a user-managed personal access token. + +### Create a Service Principal in Databricks +Follow the instructions in the Databricks documentation to create a service principal and retrieve the client_id and client_secret: + +[Authenticate access to Databricks using OAuth M2M](https://docs.databricks.com/en/dev-tools/auth/oauth-m2m.html) + +Once you have the service principal credentials, update your secrets.toml as shown bellow. + +### Configuration + +Add the following fields to your `.dlt/secrets.toml` file: +```toml +[destination.databricks.credentials] +server_hostname = "MY_DATABRICKS.azuredatabricks.net" +http_path = "/sql/1.0/warehouses/12345" +catalog = "my_catalog" +client_id = "XXX" +client_secret = "XXX" +``` + ## Loader setup guide **1. Initialize a project with a pipeline that loads to Databricks by running** @@ -118,7 +141,7 @@ Example: [destination.databricks.credentials] server_hostname = "MY_DATABRICKS.azuredatabricks.net" http_path = "/sql/1.0/warehouses/12345" -access_token = "MY_ACCESS_TOKEN" +access_token = "MY_ACCESS_TOKEN" # Replace for client_id and client_secret when using OAuth catalog = "my_catalog" ``` diff --git a/poetry.lock b/poetry.lock index 83090360b0..82d9bf90f8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "about-time" @@ -2208,6 +2208,26 @@ nr-date = ">=2.0.0,<3.0.0" typeapi = ">=2.0.1,<3.0.0" typing-extensions = ">=3.10.0" +[[package]] +name = "databricks-sdk" +version = "0.39.0" +description = "Databricks SDK for Python (Beta)" +optional = true +python-versions = ">=3.7" +files = [ + {file = "databricks_sdk-0.39.0-py3-none-any.whl", hash = "sha256:915fbf12b249264f74ddae2ca739530e3c4a9c5a454617ac403115d6466c2f99"}, + {file = "databricks_sdk-0.39.0.tar.gz", hash = "sha256:2e04edbb9e050f4362da804fb5dad07637c5adecfcffb4d0ca8abb5aefa36d06"}, +] + +[package.dependencies] +google-auth = ">=2.0,<3.0" +requests = ">=2.28.1,<3" + +[package.extras] +dev = ["autoflake", "databricks-connect", "httpx", "ipython", "ipywidgets", "isort", "langchain-openai", "openai", "pycodestyle", "pyfakefs", "pytest", "pytest-cov", "pytest-mock", "pytest-rerunfailures", "pytest-xdist", "requests-mock", "wheel", "yapf"] +notebook = ["ipython (>=8,<9)", "ipywidgets (>=8,<9)"] +openai = ["httpx", "langchain-openai", "openai"] + [[package]] name = "databricks-sql-connector" version = "2.9.6" @@ -10680,7 +10700,7 @@ az = ["adlfs"] bigquery = ["db-dtypes", "gcsfs", "google-cloud-bigquery", "grpcio", "pyarrow"] cli = ["cron-descriptor", "pipdeptree"] clickhouse = ["adlfs", "clickhouse-connect", "clickhouse-driver", "gcsfs", "pyarrow", "s3fs"] -databricks = ["databricks-sql-connector"] +databricks = ["databricks-sdk", "databricks-sql-connector"] deltalake = ["deltalake", "pyarrow"] dremio = ["pyarrow"] duckdb = ["duckdb"] @@ -10707,4 +10727,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "84e8b8eccd9b8ee104a2dc08f5b83987aeb06540d61330390ce849cc1ad6acb4" +content-hash = "5513aca05ae04d7941f2a890d0fefa86a08371508a2d319c1e558c29ff8a45f3" diff --git a/pyproject.toml b/pyproject.toml index bfa830cd06..d12073601d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ db-dtypes = { version = ">=1.2.0", optional = true } # pyiceberg = { version = ">=0.7.1", optional = true, extras = ["sql-sqlite"] } # we will rely on manual installation of `sqlalchemy>=2.0.18` instead pyiceberg = { version = ">=0.8.1", python = ">=3.9", optional = true } +databricks-sdk = {version = ">=0.38.0", optional = true} [tool.poetry.extras] gcp = ["grpcio", "google-cloud-bigquery", "db-dtypes", "gcsfs"] @@ -117,7 +118,7 @@ weaviate = ["weaviate-client"] mssql = ["pyodbc"] synapse = ["pyodbc", "adlfs", "pyarrow"] qdrant = ["qdrant-client"] -databricks = ["databricks-sql-connector"] +databricks = ["databricks-sql-connector", "databricks-sdk"] clickhouse = ["clickhouse-driver", "clickhouse-connect", "s3fs", "gcsfs", "adlfs", "pyarrow"] dremio = ["pyarrow"] lancedb = ["lancedb", "pyarrow", "tantivy"] diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py index e27da4db2a..8b3beed2b3 100644 --- a/tests/load/databricks/test_databricks_configuration.py +++ b/tests/load/databricks/test_databricks_configuration.py @@ -4,6 +4,7 @@ pytest.importorskip("databricks") from dlt.common.exceptions import TerminalValueError +from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob from dlt.common.configuration import resolve_configuration @@ -86,3 +87,12 @@ def test_databricks_abfss_converter() -> None: abfss_url == "abfss://dlt-ci-test-bucket@my_account.dfs.core.windows.net/path/to/file.parquet" ) + + +def test_databricks_auth_invalid() -> None: + with pytest.raises(ConfigurationValueError, match="No valid authentication method detected.*"): + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__ACCESS_TOKEN"] = "" + bricks = databricks() + bricks.configuration(None, accept_partial=True) diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index e802cde693..078dce3a7f 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -2,6 +2,7 @@ import os from dlt.common.utils import uniq_id +from dlt.destinations import databricks from tests.load.utils import ( GCS_BUCKET, DestinationTestConfiguration, @@ -23,6 +24,10 @@ ids=lambda x: x.name, ) def test_databricks_external_location(destination_config: DestinationTestConfiguration) -> None: + # force token-based authentication + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" + # do not interfere with state os.environ["RESTORE_FROM_DESTINATION"] = "False" # let the package complete even with failed jobs @@ -145,3 +150,54 @@ def test_databricks_gcs_external_location(destination_config: DestinationTestCon assert ( "credential_x" in pipeline.list_failed_jobs_in_package(info.loads_ids[0])[0].failed_message ) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=("databricks",)), + ids=lambda x: x.name, +) +def test_databricks_auth_oauth(destination_config: DestinationTestConfiguration) -> None: + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__ACCESS_TOKEN"] = "" + bricks = databricks() + config = bricks.configuration(None, accept_partial=True) + assert config.credentials.client_id and config.credentials.client_secret + assert not config.credentials.access_token + + dataset_name = "test_databricks_oauth" + uniq_id() + pipeline = destination_config.setup_pipeline( + "test_databricks_oauth", dataset_name=dataset_name, destination=bricks + ) + + info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) + assert info.has_failed_jobs is False + + with pipeline.sql_client() as client: + rows = client.execute_sql(f"select * from {dataset_name}.digits") + assert len(rows) == 3 + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=("databricks",)), + ids=lambda x: x.name, +) +def test_databricks_auth_token(destination_config: DestinationTestConfiguration) -> None: + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" + bricks = databricks() + config = bricks.configuration(None, accept_partial=True) + assert config.credentials.access_token + assert not (config.credentials.client_secret and config.credentials.client_id) + + dataset_name = "test_databricks_token" + uniq_id() + pipeline = destination_config.setup_pipeline( + "test_databricks_token", dataset_name=dataset_name, destination=bricks + ) + + info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) + assert info.has_failed_jobs is False + + with pipeline.sql_client() as client: + rows = client.execute_sql(f"select * from {dataset_name}.digits") + assert len(rows) == 3