diff --git a/NEWS.md b/NEWS.md
index 9d20db3..b620940 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -7,6 +7,11 @@
 * No longer install 'rpy2' by default. It will prompt user for installation
 the first time `spark_apply()` is called (#125)
 
+* Added support for Databricks serverless interactive compute (#127)
+
+* Extended authentication method support for Databricks by deferring to the
+SDK (#127)
+
 # pysparklyr 0.1.5
 
 ### Improvements
 
diff --git a/R/databricks-utils.R b/R/databricks-utils.R
index 57eb233..eab2206 100644
--- a/R/databricks-utils.R
+++ b/R/databricks-utils.R
@@ -27,6 +27,10 @@ databricks_host <- function(host = NULL, fail = TRUE) {
 }
 
 databricks_token <- function(token = NULL, fail = FALSE) {
+  # If a token is provided, return it unchanged.
+  # Otherwise, search for a token in this order:
+  # DATABRICKS_TOKEN > CONNECT_DATABRICKS_TOKEN > .rs.api.getDatabricksToken
+
   if (!is.null(token)) {
     return(set_names(token, "argument"))
   }
@@ -53,7 +57,6 @@ databricks_token <- function(token = NULL, fail = FALSE) {
       paste0(
         "No authentication token was identified: \n",
         " - No 'DATABRICKS_TOKEN' environment variable found \n",
-        " - No Databricks OAuth token found \n",
         " - Not passed as a function argument"
       ),
       "Please add your Token to 'DATABRICKS_TOKEN' inside your .Renviron file."
@@ -66,15 +69,13 @@ databricks_token <- function(token = NULL, fail = FALSE) {
 }
 
 databricks_dbr_version_name <- function(cluster_id,
-                                        host = NULL,
-                                        token = NULL,
+                                        client,
                                         silent = FALSE) {
   bullets <- NULL
   version <- NULL
   cluster_info <- databricks_dbr_info(
     cluster_id = cluster_id,
-    host = host,
-    token = token,
+    client = client,
     silent = silent
   )
   cluster_name <- substr(cluster_info$cluster_name, 1, 100)
@@ -96,8 +97,7 @@ databricks_extract_version <- function(x) {
 }
 
 databricks_dbr_info <- function(cluster_id,
-                                host = NULL,
-                                token = NULL,
+                                client,
                                 silent = FALSE) {
   cli_div(theme = cli_colors())
@@ -109,10 +109,10 @@ databricks_dbr_info <- function(cluster_id,
     )
   }
 
-  out <- databricks_cluster_get(cluster_id, host, token)
+  out <- databricks_cluster_get(cluster_id, client)
   if (inherits(out, "try-error")) {
-    sanitized <- sanitize_host(host, silent)
-    out <- databricks_cluster_get(cluster_id, sanitized, token)
+    # sanitized <- sanitize_host(host, silent)
+    out <- databricks_cluster_get(cluster_id, client)
   }
 
   if (inherits(out, "try-error")) {
@@ -159,30 +159,17 @@ databricks_dbr_info <- function(cluster_id,
   out
 }
 
-databricks_dbr_version <- function(cluster_id,
-                                   host = NULL,
-                                   token = NULL) {
+databricks_dbr_version <- function(cluster_id, client) {
   vn <- databricks_dbr_version_name(
     cluster_id = cluster_id,
-    host = host,
-    token = token
+    client = client
   )
   vn$version
 }
 
-databricks_cluster_get <- function(cluster_id,
-                                   host = NULL,
-                                   token = NULL) {
+databricks_cluster_get <- function(cluster_id, client) {
   try(
-    paste0(
-      host,
-      "/api/2.0/clusters/get"
-    ) %>%
-      request() %>%
-      req_auth_bearer_token(token) %>%
-      req_body_json(list(cluster_id = cluster_id)) %>%
-      req_perform() %>%
-      resp_body_json(),
+    client$clusters$get(cluster_id = cluster_id)$as_dict(),
     silent = TRUE
   )
 }
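The hand-rolled httr2 call to `/api/2.0/clusters/get` is replaced with the Python SDK's typed client. As a minimal sketch of what the new `databricks_cluster_get()` does, assuming the `databricks-sdk` package is installed in the active Python environment (the cluster ID is a placeholder):

library(reticulate)

db_sdk <- import("databricks.sdk")

# The WorkspaceClient resolves host and credentials through the SDK's
# default authentication flow (env vars, config profiles, OAuth, ...).
client <- db_sdk$WorkspaceClient()

# Equivalent of databricks_cluster_get(cluster_id, client):
info <- client$clusters$get(cluster_id = "0123-456789-abcdefgh")$as_dict()
info$spark_version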
@@ -227,25 +214,39 @@ databricks_dbr_error <- function(error) {
   )
 }
 
-sanitize_host <- function(url, silent = FALSE) {
-  parsed_url <- url_parse(url)
-  new_url <- url_parse("http://localhost")
-  if (is.null(parsed_url$scheme)) {
-    new_url$scheme <- "https"
-    if (!is.null(parsed_url$path) && is.null(parsed_url$hostname)) {
-      new_url$hostname <- parsed_url$path
-    }
-  } else {
-    new_url$scheme <- parsed_url$scheme
-    new_url$hostname <- parsed_url$hostname
+# from httr2
+is_hosted_session <- function() {
+  if (nzchar(Sys.getenv("COLAB_RELEASE_TAG"))) {
+    return(TRUE)
   }
-  ret <- url_build(new_url)
-  if (ret != url && !silent) {
-    cli_div(theme = cli_colors())
-    cli_alert_warning(
-      "{.header Changing host URL to:} {.emph {ret}}"
+
+  Sys.getenv("RSTUDIO_PROGRAM_MODE") == "server" &&
+    !grepl("localhost", Sys.getenv("RSTUDIO_HTTP_REFERER"), fixed = TRUE)
+}
+
+databricks_desktop_login <- function(host = NULL, profile = NULL) {
+  # host takes priority over profile
+  if (!is.null(host)) {
+    method <- "--host"
+    value <- host
+  } else if (!is.null(profile)) {
+    method <- "--profile"
+    value <- profile
+  } else {
+    # TODO: use an rlang error?
+    stop("must specify `host` or `profile`; neither was set")
+  }
+
+  cli_path <- Sys.getenv("DATABRICKS_CLI_PATH", "databricks")
+  if (!is_hosted_session() && nchar(Sys.which(cli_path)) != 0) {
+    # When on desktop, try using the Databricks CLI for auth.
+    output <- suppressWarnings(
+      system2(
+        cli_path,
+        c("auth", "login", method, value),
+        stdout = TRUE,
+        stderr = TRUE
+      )
     )
-    cli_end()
   }
-  ret
 }
diff --git a/R/deploy.R b/R/deploy.R
index c99a98c..1e53bdc 100644
--- a/R/deploy.R
+++ b/R/deploy.R
@@ -52,6 +52,7 @@ deploy_databricks <- function(
 
   cluster_id <- cluster_id %||% Sys.getenv("DATABRICKS_CLUSTER_ID")
 
+  # TODO: adjust this to use `client`; might need a refactor
   if (is.null(version) && !is.null(cluster_id)) {
     version <- databricks_dbr_version(
       cluster_id = cluster_id,
diff --git a/R/python-install.R b/R/python-install.R
index 0fafe49..3351488 100644
--- a/R/python-install.R
+++ b/R/python-install.R
@@ -217,7 +217,8 @@ install_environment <- function(
     "PyArrow",
     "grpcio",
     "google-api-python-client",
-    "grpcio_status"
+    "grpcio_status",
+    "databricks-sdk"
   )
 
   if (add_torch && install_ml) {
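With `databricks-sdk` added to the package list, newly built Python environments ship the SDK alongside `databricks-connect`. A usage sketch for rebuilding the environment with the new dependency (version and cluster ID are placeholders):

# Build or refresh the Python environment for a given DBR version:
pysparklyr::install_databricks(version = "14.1")

# Or let pysparklyr derive the version from an existing cluster:
pysparklyr::install_databricks(cluster_id = "0123-456789-abcdefgh")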
diff --git a/R/sparklyr-spark-connect.R b/R/sparklyr-spark-connect.R
index 753c0d9..b1e1a3f 100644
--- a/R/sparklyr-spark-connect.R
+++ b/R/sparklyr-spark-connect.R
@@ -63,32 +63,19 @@ spark_connect_method.spark_method_databricks_connect <- function(
     ...) {
   args <- list(...)
   cluster_id <- args$cluster_id
+  serverless <- args$serverless %||% FALSE
+  profile <- args$profile
   token <- args$token
   envname <- args$envname
   host_sanitize <- args$host_sanitize %||% TRUE
   silent <- args$silent %||% FALSE
   method <- method[[1]]
+
   token <- databricks_token(token, fail = FALSE)
   cluster_id <- cluster_id %||% Sys.getenv("DATABRICKS_CLUSTER_ID")
-  master <- databricks_host(master, fail = FALSE)
-  if (host_sanitize && master != "") {
-    master <- sanitize_host(master, silent)
-  }
-
-  cluster_info <- NULL
-  if (cluster_id != "" && master != "" && token != "") {
-    cluster_info <- databricks_dbr_version_name(
-      cluster_id = cluster_id,
-      host = master,
-      token = token,
-      silent = silent
-    )
-    if (is.null(version)) {
-      version <- cluster_info$version
-    }
-  }
 
+  # load the Python environment
   envname <- use_envname(
     backend = "databricks",
     version = version,
@@ -102,16 +89,71 @@ spark_connect_method.spark_method_databricks_connect <- function(
     return(invisible)
   }
 
-  db <- import_check("databricks.connect", envname, silent)
+  # load the Python libraries
+  dbc <- import_check("databricks.connect", envname, silent)
+  db_sdk <- import_check("databricks.sdk", envname, silent = TRUE)
+
+  # SDK behaviour:
+  # https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html#default-authentication-flow
+
+  conf_args <- list()
+
+  # Use the profile as specified (the SDK's default profile is 'DEFAULT');
+  # otherwise, if a token was found, propagate it to the SDK config.
+
+  # TODO: emit messages about the connection here?
+  # Specific vars take priority; a profile only works when no env vars are set.
+  if (token != "" && master != "") {
+    conf_args$host <- master
+    conf_args$token <- token
+    conf_args$auth_type <- "pat"
+    databricks_desktop_login(host = master)
+  } else if (!is.null(profile)) {
+    conf_args$profile <- profile
+    databricks_desktop_login(profile = profile)
+  }
+
+  # serverless-related config settings
+  if (serverless) {
+    conf_args$serverless_compute_id <- "auto"
+  } else {
+    conf_args$cluster_id <- cluster_id
+  }
+
+  sdk_config <- db_sdk$core$Config(!!!conf_args)
+
+  # create the workspace client
+  sdk_client <- db_sdk$WorkspaceClient(config = sdk_config)
+
+  # if serverless is TRUE, cluster_id is overruled (set to NULL)
+  cluster_info <- NULL
+  if (!serverless) {
+    if (cluster_id != "" && master != "" && token != "") {
+      cluster_info <- databricks_dbr_version_name(
+        cluster_id = cluster_id,
+        client = sdk_client,
+        silent = silent
+      )
+      if (is.null(version)) {
+        version <- cluster_info$version
+      }
+    }
+  } else {
+    cluster_id <- NULL
+  }
 
   if (!is.null(cluster_info)) {
     msg <- "{.header Connecting to} {.emph '{cluster_info$name}'}"
     msg_done <- "{.header Connected to:} {.emph '{cluster_info$name}'}"
     master_label <- glue("{cluster_info$name} ({cluster_id})")
-  } else {
+  } else if (!serverless) {
     msg <- "{.header Connecting to} {.emph '{cluster_id}'}"
     msg_done <- "{.header Connected to:} '{.emph '{cluster_id}'}'"
     master_label <- glue("Databricks Connect - Cluster: {cluster_id}")
+  } else if (serverless) {
+    msg <- "{.header Connecting to} {.emph serverless}"
+    msg_done <- "{.header Connected to:} '{.emph serverless}'"
+    master_label <- glue("Databricks Connect - Cluster: serverless")
   }
 
   if (!silent) {
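From the caller's side, the new `serverless` and `profile` options arrive through sparklyr's `...` and are read from `args` in the hunk above. A hedged usage sketch (cluster ID and profile name are placeholders):

library(sparklyr)

# Existing path: personal access token plus an explicit cluster.
sc <- spark_connect(
  method = "databricks_connect",
  cluster_id = "0123-456789-abcdefgh"
)

# New path: defer authentication to the SDK via a ~/.databrickscfg
# profile and target serverless compute instead of a named cluster.
sc <- spark_connect(
  method = "databricks_connect",
  profile = "my-workspace",
  serverless = TRUE
)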
@@ -119,17 +161,8 @@ spark_connect_method.spark_method_databricks_connect <- function(
     cli_progress_step(msg, msg_done)
   }
 
-  remote_args <- list()
-  if (master != "") remote_args$host <- master
-  if (token != "") remote_args$token <- token
-  if (cluster_id != "") remote_args$cluster_id <- cluster_id
-
-  databricks_session <- function(...) {
-    user_agent <- build_user_agent()
-    db$DatabricksSession$builder$remote(...)$userAgent(user_agent)
-  }
-
-  conn <- exec(databricks_session, !!!remote_args)
+  user_agent <- build_user_agent()
+  conn <- dbc$DatabricksSession$builder$sdkConfig(sdk_config)$userAgent(user_agent)
 
   if (!silent) {
     cli_progress_done()
@@ -141,6 +174,7 @@ spark_connect_method.spark_method_databricks_connect <- function(
     master_label = master_label,
     con_class = "connect_databricks",
     cluster_id = cluster_id,
+    serverless = serverless,
     method = method,
     config = config
   )
@@ -151,6 +185,7 @@ initialize_connection <- function(
     master_label,
     con_class,
     cluster_id = NULL,
+    serverless = FALSE,
     method = NULL,
     config = NULL) {
   warnings <- import("warnings")
@@ -173,12 +208,15 @@ initialize_connection <- function(
     "ignore",
     message = "Index.format is deprecated and will be removed in a future version"
   )
+
   session <- conn$getOrCreate()
   get_version <- try(session$version, silent = TRUE)
   if (inherits(get_version, "try-error")) databricks_dbr_error(get_version)
-  session$conf$set("spark.sql.session.localRelationCacheThreshold", 1048576L)
-  session$conf$set("spark.sql.execution.arrow.pyspark.enabled", "true")
-  session$conf$set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")
+  if (!serverless) {
+    session$conf$set("spark.sql.session.localRelationCacheThreshold", 1048576L)
+    session$conf$set("spark.sql.execution.arrow.pyspark.enabled", "true")
+    session$conf$set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")
+  }
 
   # do we need this `spark_context` object?
   spark_context <- list(spark_context = session)
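Taken together, connection creation now reduces to the builder chain below. A sketch under stated assumptions (`databricks-connect` installed, `sdk_config` built as in the earlier hunk; the user-agent string is illustrative, the real one comes from `build_user_agent()`):

dbc <- reticulate::import("databricks.connect")

# Builder seeded with the SDK config; mirrors the `conn` object above.
conn <- dbc$DatabricksSession$builder$
  sdkConfig(sdk_config)$
  userAgent("pysparklyr-sketch")

# initialize_connection() then materializes the Spark Connect session:
session <- conn$getOrCreate()
session$version  # fetches the server version, as initialize_connection() does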