From a59ba3eb94888dce6bbb48c1edf77badd0f358d5 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Wed, 11 Dec 2024 13:22:03 +0100 Subject: [PATCH] refactor databricks auth test --- .../test_databricks_configuration.py | 10 ++++++ .../load/pipeline/test_databricks_pipeline.py | 32 ++++++++++++++++--- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py index e27da4db2a..8b3beed2b3 100644 --- a/tests/load/databricks/test_databricks_configuration.py +++ b/tests/load/databricks/test_databricks_configuration.py @@ -4,6 +4,7 @@ pytest.importorskip("databricks") from dlt.common.exceptions import TerminalValueError +from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob from dlt.common.configuration import resolve_configuration @@ -86,3 +87,12 @@ def test_databricks_abfss_converter() -> None: abfss_url == "abfss://dlt-ci-test-bucket@my_account.dfs.core.windows.net/path/to/file.parquet" ) + + +def test_databricks_auth_invalid() -> None: + with pytest.raises(ConfigurationValueError, match="No valid authentication method detected.*"): + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__ACCESS_TOKEN"] = "" + bricks = databricks() + bricks.configuration(None, accept_partial=True) diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index 8a227fd547..86cf91caa1 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -2,6 +2,7 @@ import os from dlt.common.utils import uniq_id +from dlt.destinations import databricks from tests.load.utils import ( GCS_BUCKET, DestinationTestConfiguration, @@ -151,12 +152,10 @@ def test_databricks_gcs_external_location(destination_config: DestinationTestCon destinations_configs(default_sql_configs=True, subset=("databricks",)), ids=lambda x: x.name, ) -def test_databricks_oauth(destination_config: DestinationTestConfiguration) -> None: - from dlt.destinations import databricks - +def test_databricks_auth_oauth(destination_config: DestinationTestConfiguration) -> None: + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__ACCESS_TOKEN"] = "" bricks = databricks() config = bricks.configuration(None, accept_partial=True) - assert config.credentials.client_id and config.credentials.client_secret dataset_name = "test_databricks_oauth" + uniq_id() @@ -170,3 +169,28 @@ def test_databricks_oauth(destination_config: DestinationTestConfiguration) -> N with pipeline.sql_client() as client: rows = client.execute_sql(f"select * from {dataset_name}.digits") assert len(rows) == 3 + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=("databricks",)), + ids=lambda x: x.name, +) +def test_databricks_auth_token(destination_config: DestinationTestConfiguration) -> None: + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" + bricks = databricks() + config = bricks.configuration(None, accept_partial=True) + assert config.credentials.access_token + + dataset_name = "test_databricks_token" + uniq_id() + pipeline = destination_config.setup_pipeline( + "test_databricks_token", dataset_name=dataset_name, destination=bricks + ) + + info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) + assert info.has_failed_jobs is False + + with pipeline.sql_client() as client: + rows = client.execute_sql(f"select * from {dataset_name}.digits") + assert len(rows) == 3