Add OAuth support in Databricks Loader
New `oauth` config section under `storage`:

```
"storage" : {
...
    "oauth": {
      "clientId": "client-id"
      "clientSecret": ${OAUTH_CLIENT_SECRET}
    }
}
```

This lets us set the `clientId` and `clientSecret` authentication properties. The old-style `password` field (relying on personal access tokens) is still supported.
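For illustration, a minimal sketch of an OAuth-only `storage` section (the host value is made up; the field layout follows the reference configs below). The `${OAUTH_CLIENT_SECRET}` reference is resolved through standard HOCON environment-variable substitution, so the secret never needs to be written into the file:

```
"storage": {
  "host": "abc.cloud.databricks.com",
  "schema": "atomic",
  "oauth": {
    "clientId": "client-id"
    "clientSecret": ${OAUTH_CLIENT_SECRET}
  }
}
```

As the `setAuthProperties` change below shows, a configured `oauth` block takes precedence: the loader switches the JDBC driver to OAuth (`AuthMech` 11) and falls back to token-based auth (`AuthMech` 3) only when `oauth` is absent.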
pondzix committed Aug 22, 2024
1 parent 618565a commit 4f9b068
Showing 13 changed files with 114 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -55,7 +55,7 @@ jobs:
java-version: 11
- name: Get the latest Databricks JDBC driver
run: |
- curl https://databricks-bi-artifacts.s3.us-east-2.amazonaws.com/simbaspark-drivers/jdbc/2.6.34/DatabricksJDBC42-2.6.34.1058.zip --output DatabricksJDBC42.jar.zip
+ curl https://databricks-bi-artifacts.s3.us-east-2.amazonaws.com/simbaspark-drivers/jdbc/2.6.40/DatabricksJDBC42-2.6.40.1070.zip --output DatabricksJDBC42.jar.zip
unzip DatabricksJDBC42.jar.zip
cp ./*/DatabricksJDBC42.jar . # 2.6.34 download changes directory structure - grab the jar from nested directory (which has entropy in its name)
- name: Docker login
4 changes: 4 additions & 0 deletions config/loader/aws/databricks.config.reference.hocon
@@ -24,6 +24,10 @@
"parameterName": "snowplow.databricks.password"
}
},
"oauth": {
"clientId": "client-id"
"clientSecret": ${OAUTH_CLIENT_SECRET}
},
# Optional. Override the Databricks default catalog, e.g. with a Unity catalog name.
"catalog": "hive_metastore",
# DB schema
4 changes: 4 additions & 0 deletions config/loader/azure/databricks.config.reference.hocon
@@ -43,6 +43,10 @@
"parameterName": "snowplow.databricks.password"
}
},
"oauth": {
"clientId": "client-id"
"clientSecret": ${OAUTH_CLIENT_SECRET}
},
# Optional. Override the Databricks default catalog, e.g. with a Unity catalog name.
"catalog": "hive_metastore",
# DB schema
4 changes: 4 additions & 0 deletions config/loader/gcp/databricks.config.reference.hocon
@@ -21,6 +21,10 @@
"parameterName": "snowplow.databricks.password"
}
},
"oauth": {
"clientId": "client-id"
"clientSecret": ${OAUTH_CLIENT_SECRET}
},
# Optional. Override the Databricks default catalog, e.g. with a Unity catalog name.
"catalog": "hive_metastore",
# DB schema
@@ -112,7 +112,8 @@ class ConfigSpec extends Specification {
val result = getConfigFromResource("/loader/aws/databricks.config.minimal.hocon", testParseConfig)
val storage = ConfigSpec.exampleStorage.copy(
catalog = None,
- password = StorageTarget.PasswordConfig.PlainText("Supersecret1")
+ password = Some(StorageTarget.PasswordConfig.PlainText("Supersecret1")),
+ oauth = None
)
val cloud = Config.Cloud.AWS(RegionSpec.DefaultTestRegion, exampleMessageQueue.copy(region = Some(RegionSpec.DefaultTestRegion)))
val retries = exampleRetries.copy(cumulativeBound = Some(20.minutes))
@@ -140,7 +141,8 @@ class ConfigSpec extends Specification {
val result = getConfigFromResource("/loader/gcp/databricks.config.minimal.hocon", testParseConfig)
val storage = ConfigSpec.exampleStorage.copy(
catalog = None,
- password = StorageTarget.PasswordConfig.PlainText("Supersecret1")
+ password = Some(StorageTarget.PasswordConfig.PlainText("Supersecret1")),
+ oauth = None
)
val retries = exampleRetries.copy(cumulativeBound = Some(20.minutes))
val readyCheck = exampleReadyCheck.copy(strategy = Config.Strategy.Constant, backoff = 15.seconds)
@@ -167,7 +169,8 @@
val result = getConfigFromResource("/loader/azure/databricks.config.minimal.hocon", testParseConfig)
val storage = ConfigSpec.exampleStorage.copy(
catalog = None,
- password = StorageTarget.PasswordConfig.PlainText("Supersecret1")
+ password = Some(StorageTarget.PasswordConfig.PlainText("Supersecret1")),
+ oauth = None
)
val retries = exampleRetries.copy(cumulativeBound = Some(20.minutes))
val readyCheck = exampleReadyCheck.copy(strategy = Config.Strategy.Constant, backoff = 15.seconds)
@@ -200,7 +203,8 @@ object ConfigSpec {
"atomic",
443,
"/databricks/http/path",
- StorageTarget.PasswordConfig.EncryptedKey(StorageTarget.EncryptedConfig("snowplow.databricks.password")),
+ Some(StorageTarget.PasswordConfig.EncryptedKey(StorageTarget.EncryptedConfig("snowplow.databricks.password"))),
+ Some(StorageTarget.Databricks.OAuth("client-id", "client-secret")),
None,
"snowplow-rdbloader-oss",
StorageTarget.LoadAuthMethod.NoCreds,
@@ -227,7 +227,8 @@ object DatabricksSpec {
"snowplow",
443,
"some/path",
- StorageTarget.PasswordConfig.PlainText("xxx"),
+ Some(StorageTarget.PasswordConfig.PlainText("xxx")),
+ None,
None,
"useragent",
StorageTarget.LoadAuthMethod.NoCreds,
@@ -348,9 +348,10 @@ object Config {
private def azureVaultCheck(config: Config[StorageTarget]): List[String] =
config.cloud match {
case c: Config.Cloud.Azure if c.azureVaultName.isEmpty =>
- (config.storage.password, config.storage.sshTunnel.flatMap(_.bastion.key)) match {
- case (_: StorageTarget.PasswordConfig.EncryptedKey, _) | (_, Some(_)) => List("Azure vault name is needed")
- case _ => Nil
+ (config.storage.credentials, config.storage.sshTunnel.flatMap(_.bastion.key)) match {
+ case (Some(StorageTarget.Credentials(_, _: StorageTarget.PasswordConfig.EncryptedKey)), _) | (_, Some(_)) =>
+ List("Azure vault name is needed")
+ case _ => Nil
}
case _ => Nil
}
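To make the check concrete: a config fragment along these lines (values illustrative; the `secretStore` wrapper key is an assumption inferred from the `parameterName` nesting in the reference configs above) now fails validation on Azure unless a vault name is also set, because an `EncryptedKey` password can only be fetched from Azure Key Vault:

```
# rejected by azureVaultCheck unless "azureVaultName" is configured
"storage": {
  "password": {
    "secretStore": {
      "parameterName": "snowplow.databricks.password"
    }
  }
}
```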
@@ -32,8 +32,7 @@ import scala.concurrent.duration.{Duration, FiniteDuration}
*/
sealed trait StorageTarget extends Product with Serializable {
def schema: String
- def username: String
- def password: StorageTarget.PasswordConfig
+ def credentials: Option[StorageTarget.Credentials]
def sshTunnel: Option[StorageTarget.TunnelConfig]

def doobieCommitStrategy(timeouts: Config.Timeouts): Strategy = Transaction.defaultStrategy(timeouts)
@@ -80,6 +79,9 @@ object StorageTarget {
sshTunnel: Option[TunnelConfig],
loadAuthMethod: LoadAuthMethod
) extends StorageTarget {

+ override def credentials: Option[Credentials] = Some(Credentials(username, password))

override def driver: String = "com.amazon.redshift.jdbc42.Driver"

override def connectionUrl: String = s"jdbc:redshift://$host:$port/$database"
@@ -107,15 +109,17 @@
schema: String,
port: Int,
httpPath: String,
- password: PasswordConfig,
+ password: Option[PasswordConfig],
+ oauth: Option[Databricks.OAuth],
sshTunnel: Option[TunnelConfig],
userAgent: String,
loadAuthMethod: LoadAuthMethod,
eventsOptimizePeriod: FiniteDuration,
logLevel: Int
) extends StorageTarget {

- override def username: String = "token"
+ override def credentials: Option[Credentials] =
+ password.map(configuredPassword => Credentials(username = "token", password = configuredPassword))

override def driver: String = "com.databricks.client.jdbc.Driver"

@@ -130,18 +134,37 @@
props.put("httpPath", httpPath)
props.put("ssl", 1)
props.put("LogLevel", logLevel)
props.put("AuthMech", 3)
props.put("transportMode", "http")
props.put("UserAgentEntry", userAgent)
+ setAuthProperties(props)
props
}

+ private def setAuthProperties(props: Properties) =
+ oauth match {
+ case Some(configuredOAuth) =>
+ props.put("AuthMech", 11)
+ props.put("Auth_Flow", 1)
+ props.put("OAuth2ClientId", configuredOAuth.clientId)
+ props.put("OAuth2Secret", configuredOAuth.clientSecret)
+ case None =>
+ // When no OAuth is configured, fall back to the legacy personal access token auth (represented by the 'Credentials' class)
+ props.put("AuthMech", 3)
+ }

override def eventsLoadAuthMethod: LoadAuthMethod = loadAuthMethod
override def foldersLoadAuthMethod: LoadAuthMethod = loadAuthMethod

override def reportRecoveryTableMetrics: Boolean = false
}

+ object Databricks {
+ final case class OAuth(clientId: String, clientSecret: String)
+
+ implicit def oauthConfigDecoder: Decoder[OAuth] =
+ deriveDecoder[OAuth]
+ }

final case class Snowflake(
snowflakeRegion: Option[String],
username: String,
@@ -159,6 +182,8 @@
readyCheck: Snowflake.ReadyCheck
) extends StorageTarget {

+ override def credentials: Option[Credentials] = Some(Credentials(username, password))

override def connectionUrl: String =
host match {
case Right(h) =>
@@ -352,6 +377,8 @@
/** Destination socket for SSH tunnel - usually DB socket inside private network */
final case class DestinationConfig(host: String, port: Int)

+ final case class Credentials(username: String, password: PasswordConfig)

/**
* ADT representing fact that password can be either plain-text or encrypted in EC2 Parameter
* Store or GCP Secret Manager
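Pulling the new wiring together, here is a compact, self-contained sketch of the property mapping that `setAuthProperties` performs above (the `AuthMech`/`Auth_Flow` values are taken from that diff; the ADT is illustrative, not the loader's actual model):

```scala
import java.util.Properties

// Illustrative stand-in for the two Databricks auth modes.
sealed trait DatabricksAuth
final case class OAuthM2M(clientId: String, clientSecret: String) extends DatabricksAuth
case object PersonalAccessToken extends DatabricksAuth

def authProperties(auth: DatabricksAuth): Properties = {
  val props = new Properties()
  auth match {
    case OAuthM2M(id, secret) =>
      props.put("AuthMech", 11) // OAuth 2.0
      props.put("Auth_Flow", 1) // machine-to-machine (client credentials) flow
      props.put("OAuth2ClientId", id)
      props.put("OAuth2Secret", secret)
    case PersonalAccessToken =>
      // Legacy mode: the token itself travels as the JDBC password
      // (with username "token") rather than as a driver property.
      props.put("AuthMech", 3)
  }
  props
}
```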
@@ -95,8 +95,7 @@ object Environment {
httpClient <- BlazeClientBuilder[F].withExecutionContext(global.compute).resource
implicit0(logger: Logger[F]) = Slf4jLogger.getLogger[F]
iglu <- Iglu.igluInterpreter(httpClient, cli.resolverConfig)
- implicit0(logging: Logging[F]) =
- Logging.loggingInterpreter[F](List(cli.config.storage.password.getUnencrypted, cli.config.storage.username))
+ implicit0(logging: Logging[F]) = getLoggingInterpreter[F](cli.config)
implicit0(random: Random[F]) <- Resource.eval(Random.scalaUtilRandom[F])
tracker <- Monitoring.initializeTracking[F](cli.config.monitoring, httpClient)
sentry <- Sentry.init[F](cli.config.monitoring.sentry.map(_.dsn))
@@ -148,6 +147,14 @@
telemetry
)

+ private def getLoggingInterpreter[F[_]: Async](config: Config[StorageTarget]): Logging[F] = {
+ val stopWords = config.storage.credentials match {
+ case Some(configuredCredentials) => List(configuredCredentials.password.getUnencrypted, configuredCredentials.username)
+ case None => List.empty
+ }
+ Logging.loggingInterpreter[F](stopWords)
+ }

def createCloudServices[F[_]: Async: Logger: Cache](
config: Config[StorageTarget],
control: Control[F]
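The credentials-derived list built here feeds the logger's stop words: any occurrence of a secret is scrubbed before a line is written out. A toy sketch of the idea (illustrative only, not the loader's actual `Logging` implementation):

```scala
// Mask every configured secret before the message reaches the log output.
def redact(line: String, stopWords: List[String]): String =
  stopWords.filter(_.nonEmpty).foldLeft(line)((acc, secret) => acc.replace(secret, "******"))

redact("authenticating as token/Supersecret1", List("Supersecret1", "token"))
// "authenticating as ******/******"
```

With an OAuth-only Databricks target there are no JDBC credentials, so the stop-word list is simply empty.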
@@ -74,12 +74,13 @@ trait Transaction[F[_], C[_]] {
object Transaction {

/** Should be enough for all monitoring and loading */
- val PoolSize = 4
+ private val PoolSize = 4

def apply[F[_], C[_]](implicit ev: Transaction[F, C]): Transaction[F, C] = ev

- def configureHikari[F[_]: Sync](target: StorageTarget, ds: HikariConfig): F[Unit] =
+ private def configureHikari[F[_]: Sync: SecretStore](target: StorageTarget, ds: HikariConfig): F[Unit] =
Sync[F].delay {
+ ds.setJdbcUrl(target.connectionUrl)
ds.setAutoCommit(target.withAutoCommit)
ds.setMaximumPoolSize(PoolSize)

@@ -97,27 +98,44 @@ object Transaction {
ds.setMinimumIdle(0)

ds.setDataSourceProperties(target.properties)
- }
+ } *> setJdbcCredentials[F](target, ds)

- def buildPool[F[_]: Async: SecretStore: Logging: Sleep](
+ private def buildPool[F[_]: Async: SecretStore: Logging: Sleep](
target: StorageTarget,
retries: Config.Retries
): Resource[F, Transactor[F]] =
for {
- ce <- ExecutionContexts.fixedThreadPool[F](2)
- password <- target.password match {
- case StorageTarget.PasswordConfig.PlainText(text) =>
- Resource.pure[F, String](text)
- case StorageTarget.PasswordConfig.EncryptedKey(StorageTarget.EncryptedConfig(parameterName)) =>
- Resource.eval(SecretStore[F].getValue(parameterName))
- }
- xa <- HikariTransactor
- .newHikariTransactor[F](target.driver, target.connectionUrl, target.username, password, ce)
- _ <- Resource.eval(xa.configure(configureHikari[F](target, _)))
+ xa <- getTransactor(target)
xa <- Resource.pure(RetryingTransactor.wrap(retries, xa))
xa <- target.sshTunnel.fold(Resource.pure[F, Transactor[F]](xa))(SSH.transactor(_, xa))
} yield xa

+ private def getTransactor[F[_]: Async: SecretStore](target: StorageTarget): Resource[F, HikariTransactor[F]] =
+ for {
+ ec <- ExecutionContexts.fixedThreadPool[F](2)
+ _ <- Resource.eval(Async[F].delay(Class.forName(target.driver)))
+ xa <- HikariTransactor.initial[F](ec)
+ _ <- Resource.eval(xa.configure(configureHikari[F](target, _)))
+ } yield xa
+
+ private def setJdbcCredentials[F[_]: Sync: SecretStore](target: StorageTarget, ds: HikariConfig): F[Unit] =
+ target.credentials match {
+ case Some(configuredCredentials) =>
+ getPassword[F](configuredCredentials).map { password =>
+ ds.setUsername(configuredCredentials.username)
+ ds.setPassword(password)
+ }
+ case None => Sync[F].unit
+ }
+
+ private def getPassword[F[_]: Sync: SecretStore](credentials: StorageTarget.Credentials): F[String] =
+ credentials.password match {
+ case StorageTarget.PasswordConfig.PlainText(text) =>
+ Sync[F].pure(text)
+ case StorageTarget.PasswordConfig.EncryptedKey(StorageTarget.EncryptedConfig(parameterName)) =>
+ SecretStore[F].getValue(parameterName)
+ }

/**
* Build a necessary (dry-run or real-world) DB interpreter as a `Resource`, which guarantees to
* close a JDBC connection. If connection could not be acquired, it will retry several times
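The restructuring splits pool creation into two phases: build an empty `HikariTransactor` first, then push the URL and the now-optional credentials into its config. A condensed sketch of that pattern, assuming doobie-hikari (names like `sketchPool` are illustrative):

```scala
import cats.effect.{Async, Resource}
import doobie.hikari.HikariTransactor
import doobie.util.ExecutionContexts

// Two-phase pool setup: credentials are applied only when present,
// which is what lets an OAuth-only Databricks target skip username/password.
def sketchPool[F[_]: Async](
  driver: String,
  url: String,
  creds: Option[(String, String)] // (username, password); None for OAuth-only
): Resource[F, HikariTransactor[F]] =
  for {
    ec <- ExecutionContexts.fixedThreadPool[F](2)
    _  <- Resource.eval(Async[F].delay(Class.forName(driver))) // driver is no longer passed to the transactor builder
    xa <- HikariTransactor.initial[F](ec)                      // starts from an empty HikariConfig
    _  <- Resource.eval(xa.configure { ds =>
            Async[F].delay {
              ds.setJdbcUrl(url)
              creds.foreach { case (user, pass) =>
                ds.setUsername(user)
                ds.setPassword(pass)
              }
            }
          })
  } yield xa
```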
@@ -90,7 +90,7 @@ class CliConfigSpec extends Specification {
val result = CliConfig.parse[IO](cli).value.unsafeRunSync()

result must beRight.like { case CliConfig(config, _, resolverConfig) =>
- config.storage.password.getUnencrypted must beEqualTo("Supersecret password from substitution!")
+ config.storage.credentials.get.password.getUnencrypted must beEqualTo("Supersecret password from substitution!")
resolverConfig must beEqualTo(expectedResolver)
}

@@ -162,6 +162,10 @@ class StorageTargetSpec extends Specification {
"port": 443,
"httpPath": "http/path",
"password": "Supersecret1",
"oauth": {
"clientId": "client-id",
"clientSecret": "client-secret"
},
"userAgent": "snowplow-rdbloader-oss",
"eventsOptimizePeriod": "2 days",
"loadAuthMethod": {
@@ -177,7 +181,8 @@
schema = "snowplow",
port = 443,
httpPath = "http/path",
- password = StorageTarget.PasswordConfig.PlainText("Supersecret1"),
+ password = Some(StorageTarget.PasswordConfig.PlainText("Supersecret1")),
+ oauth = Some(StorageTarget.Databricks.OAuth("client-id", "client-secret")),
sshTunnel = None,
userAgent = "snowplow-rdbloader-oss",
loadAuthMethod = StorageTarget.LoadAuthMethod.NoCreds,
8 changes: 7 additions & 1 deletion project/BuildSettings.scala
@@ -243,7 +243,13 @@ object BuildSettings {
Docker / packageName := "rdb-loader-databricks",
initialCommands := "import com.snowplowanalytics.snowplow.loader.databricks._",
Compile / mainClass := Some("com.snowplowanalytics.snowplow.loader.databricks.Main"),
- Compile / unmanagedJars += file("DatabricksJDBC42.jar")
+ Compile / unmanagedJars += file("DatabricksJDBC42.jar"),
+ // used in extended configuration parsing unit tests
+ Test / envVars := Map(
+ "OAUTH_CLIENT_SECRET" -> "client-secret"
+ ),
+ // envVars works only when fork is enabled
+ Test / fork := true
) ++ buildSettings ++ addExampleConfToTestCp ++ assemblySettings ++ dynVerSettings

lazy val transformerBatchBuildSettings =
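One detail worth spelling out: sbt applies `envVars` only to forked JVMs, hence the paired `Test / fork := true`. Inside a forked test run the variable is then visible through the ordinary environment, which is what the HOCON substitution in the test fixtures relies on (a minimal sketch):

```scala
// Visible only because the test JVM was forked with Test / envVars set;
// HOCON resolves ${OAUTH_CLIENT_SECRET} against this same environment.
val secret: Option[String] = sys.env.get("OAUTH_CLIENT_SECRET") // Some("client-secret")
```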
