From dc726b90e3b396362737f191e46146ccf235ae00 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Tue, 10 Sep 2024 18:23:50 -0400 Subject: [PATCH 01/20] [query] Unify CloudCredentials and Modernise BatchService API --- .../main/scala/is/hail/HailFeatureFlags.scala | 2 +- .../is/hail/backend/local/LocalBackend.scala | 4 +- .../hail/backend/service/ServiceBackend.scala | 159 ++++----- .../is/hail/backend/service/Worker.scala | 13 +- .../is/hail/backend/spark/SparkBackend.scala | 2 +- .../scala/is/hail/io/fs/AzureStorageFS.scala | 64 +--- hail/src/main/scala/is/hail/io/fs/FS.scala | 58 +--- .../scala/is/hail/io/fs/GoogleStorageFS.scala | 117 +++---- .../main/scala/is/hail/io/fs/RouterFS.scala | 49 +++ .../is/hail/io/fs/TerraAzureStorageFS.scala | 25 +- .../scala/is/hail/services/BatchClient.scala | 311 ++++++++++++++++++ .../scala/is/hail/services/BatchConfig.scala | 16 +- .../scala/is/hail/services/Requester.scala | 168 ---------- .../services/batch_client/BatchClient.scala | 263 --------------- .../main/scala/is/hail/services/oauth2.scala | 98 ++++++ .../main/scala/is/hail/services/package.scala | 4 +- .../scala/is/hail/services/requests.scala | 101 ++++++ .../is/hail/backend/ServiceBackendSuite.scala | 145 +++----- .../is/hail/io/fs/AzureStorageFSSuite.scala | 14 +- .../is/hail/io/fs/GoogleStorageFSSuite.scala | 15 +- .../is/hail/services/BatchClientSuite.scala | 40 +++ .../batch_client/BatchClientSuite.scala | 40 --- 22 files changed, 815 insertions(+), 893 deletions(-) create mode 100644 hail/src/main/scala/is/hail/services/BatchClient.scala delete mode 100644 hail/src/main/scala/is/hail/services/Requester.scala delete mode 100644 hail/src/main/scala/is/hail/services/batch_client/BatchClient.scala create mode 100644 hail/src/main/scala/is/hail/services/oauth2.scala create mode 100644 hail/src/main/scala/is/hail/services/requests.scala create mode 100644 hail/src/test/scala/is/hail/services/BatchClientSuite.scala delete mode 100644 hail/src/test/scala/is/hail/services/batch_client/BatchClientSuite.scala diff --git a/hail/src/main/scala/is/hail/HailFeatureFlags.scala b/hail/src/main/scala/is/hail/HailFeatureFlags.scala index b9246d03016..48bb22bb390 100644 --- a/hail/src/main/scala/is/hail/HailFeatureFlags.scala +++ b/hail/src/main/scala/is/hail/HailFeatureFlags.scala @@ -47,7 +47,7 @@ object HailFeatureFlags { ), ) - def fromMap(m: Map[String, String]): HailFeatureFlags = + def fromEnv(m: Map[String, String] = sys.env): HailFeatureFlags = new HailFeatureFlags( mutable.Map( HailFeatureFlags.defaults.map { diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index 31fc698f97c..f2c130ae639 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -66,7 +66,7 @@ object LocalBackend { class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache { - private[this] val flags = HailFeatureFlags.fromMap(sys.env) + private[this] val flags = HailFeatureFlags.fromEnv() private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) def getFlag(name: String): String = flags.get(name) @@ -78,7 +78,7 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache flags.available // flags can be set after construction from python - def fs: FS = FS.buildRoutes(None, Some(flags), sys.env) + def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) 
override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = ExecutionTimer.logTime { timer => diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 07b33488c45..0ee85562352 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -5,17 +5,14 @@ import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ import is.hail.expr.Validate -import is.hail.expr.ir.{ - Compile, IR, IRParser, IRParserEnvironment, IRSize, LoweringAnalyses, MakeTuple, SortField, - TableIR, TableReader, TypeCheck, -} +import is.hail.expr.ir.{Compile, IR, IRParser, IRParserEnvironment, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck} import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering._ import is.hail.io.fs._ import is.hail.linalg.BlockMatrix -import is.hail.services._ -import is.hail.services.batch_client.BatchClient +import is.hail.services.JobGroupStates.Failure +import is.hail.services.{BatchClient, _} import is.hail.types._ import is.hail.types.physical._ import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType @@ -25,17 +22,17 @@ import is.hail.variant.ReferenceGenome import scala.annotation.switch import scala.reflect.ClassTag - import java.io._ import java.nio.charset.StandardCharsets import java.util.concurrent._ - import org.apache.log4j.Logger import org.json4s.{DefaultFormats, Formats} import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods import sourcecode.Enclosing +import java.nio.file.Path + class ServiceBackendContext( val billingProject: String, val remoteTmpDir: String, @@ -56,16 +53,22 @@ object ServiceBackend { name: String, theHailClassLoader: HailClassLoader, batchClient: BatchClient, - batchId: Option[Long], - jobGroupId: Option[Long], + batchId: Option[Int], + jobGroupId: Option[Int], scratchDir: String = sys.env.getOrElse("HAIL_WORKER_SCRATCH_DIR", ""), rpcConfig: ServiceBackendRPCPayload, env: Map[String, String], ): ServiceBackend = { - val flags = HailFeatureFlags.fromMap(rpcConfig.flags) + val flags = HailFeatureFlags.fromEnv(rpcConfig.flags) val shouldProfile = flags.get("profile") != null - val fs = FS.buildRoutes(Some(s"$scratchDir/secrets/gsa-key/key.json"), Some(flags), env) + val fs = RouterFS.buildRoutes( + CloudStorageFSConfig.fromFlagsAndEnv( + Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), + flags, + env, + ) + ) val backendContext = new ServiceBackendContext( rpcConfig.billing_project, @@ -113,8 +116,8 @@ class ServiceBackend( var name: String, val theHailClassLoader: HailClassLoader, val batchClient: BatchClient, - val curBatchId: Option[Long], - val curJobGroupId: Option[Long], + val curBatchId: Option[Int], + val curJobGroupId: Option[Int], val flags: HailFeatureFlags, val tmpdir: String, val fs: FS, @@ -178,7 +181,7 @@ class ServiceBackend( val uploadContexts = executor.submit[Unit](() => retryTransientErrors { fs.writePDOS(s"$root/contexts") { os => - var o = 12L * n + var o = 12L * n // 12L = sizeof(Long) + sizeof(Int) collection.foreach { context => val len = context.length os.writeLong(o) @@ -190,88 +193,60 @@ class ServiceBackend( } ) - uploadFunction.get() - uploadContexts.get() - - val parentJobGroup = curJobGroupId.getOrElse(0L) - val jobGroupIdInUpdate = 1 // QoB creates an update for every new stage - 
val workerJobGroup = JObject(
-      "job_group_id" -> JInt(jobGroupIdInUpdate),
-      "absolute_parent_id" -> JInt(parentJobGroup),
-      "attributes" -> JObject("name" -> JString(stageIdentifier)),
+    val jobGroup = JobGroupRequest(
+      job_group_id = 1, // QoB creates an update for every new stage
+      absolute_parent_id = curJobGroupId.getOrElse(0),
+      attributes = Map("name" -> stageIdentifier),
     )
-    log.info(s"worker job group spec: $workerJobGroup")
 
-    val jobs = collection.zipWithIndex.map { case (_, i) =>
-      var resources = JObject("preemptible" -> JBool(true))
-      if (backendContext.workerCores != "None") {
-        resources = resources.merge(JObject("cpu" -> JString(backendContext.workerCores)))
-      }
-      if (backendContext.workerMemory != "None") {
-        resources = resources.merge(JObject("memory" -> JString(backendContext.workerMemory)))
-      }
-      if (backendContext.storageRequirement != "0Gi") {
-        resources =
-          resources.merge(JObject("storage" -> JString(backendContext.storageRequirement)))
-      }
-      JObject(
-        "always_run" -> JBool(false),
-        "job_id" -> JInt(i + 1),
-        "in_update_parent_ids" -> JArray(List()),
-        "in_update_job_group_id" -> JInt(jobGroupIdInUpdate),
-        "process" -> JObject(
-          "jar_spec" -> JObject(
-            "type" -> JString("jar_url"),
-            "value" -> JString(jarLocation),
-          ),
-          "command" -> JArray(List(
-            JString(Main.WORKER),
-            JString(root),
-            JString(s"$i"),
-            JString(s"$n"),
-          )),
-          "type" -> JString("jvm"),
-          "profile" -> JBool(backendContext.profile),
-        ),
-        "attributes" -> JObject(
-          "name" -> JString(s"${name}_stage${stageCount}_${stageIdentifier}_job$i")
+    log.info(s"worker job group spec: $jobGroup")
+
+    val jobs = collection.indices.map { i =>
+      JobRequest(
+        job_id = i + 1,
+        always_run = false,
+        in_update_job_group_id = jobGroup.job_group_id,
+        in_update_parent_ids = Array(),
+        process = JvmJob(
+          command = Array(Main.WORKER, root, s"$i", s"$n"), // each worker needs its partition index, not the job group id
+          jar_url = jarLocation,
+          profile = flags.get("profile") != null,
         ),
-        "resources" -> resources,
-        "regions" -> JArray(backendContext.regions.map(JString).toList),
-        "cloudfuse" -> JArray(backendContext.cloudfuseConfig.map { config =>
-          JObject(
-            "bucket" -> JString(config.bucket),
-            "mount_path" -> JString(config.mount_path),
-            "read_only" -> JBool(config.read_only),
+        resources = Some(
+          JobResources(
+            preemptible = true,
+            cpu = Some(backendContext.workerCores).filter(_ != "None"),
+            memory = Some(backendContext.workerMemory).filter(_ != "None"),
+            storage = Some(backendContext.storageRequirement).filter(_ != "0Gi"),
           )
-        }.toList),
+        ),
+        regions = Some(backendContext.regions).filter(_.nonEmpty),
+        cloudfuse = Some(backendContext.cloudfuseConfig).filter(_.nonEmpty),
+        attributes = Map("name" -> s"${name}_stage${stageCount}_${stageIdentifier}_job$i"),
       )
     }
 
+    uploadFunction.get()
+    uploadContexts.get()
+
     log.info(s"parallelizeAndComputeWithIndex: $token: running job")
 
-    val (batchId, (updateId, jobGroupId)) = curBatchId match {
-      case Some(id) =>
-        (id, batchClient.update(id, token, workerJobGroup, jobs))
-      case None =>
-        val batchId = batchClient.create(
-          JObject(
-            "billing_project" -> JString(backendContext.billingProject),
-            "n_jobs" -> JInt(n),
-            "token" -> JString(token),
-            "attributes" -> JObject("name" -> JString(name + "_" + stageCount)),
-          ),
-          jobs,
+    val batchId = curBatchId.getOrElse {
+      batchClient.newBatch(
+        BatchRequest(
+          billing_project = backendContext.billingProject,
+          n_jobs = 0,
+          token = token,
+          attributes = Map("name" -> name),
         )
-        (batchId, (1L, 1L))
+      )
     }
 
-    val batch = 
batchClient.waitForJobGroup(batchId, jobGroupId) + val (updateId, jobGroupId) = batchClient.newJobGroup(batchId, token, jobGroup, jobs) + val response = batchClient.waitForJobGroup(batchId, jobGroupId) stageCount += 1 - implicit val formats: Formats = DefaultFormats - val batchState = (batch \ "state").extract[String] - if (batchState == "failed") { + if (response.state == Failure) { throw new HailBatchFailure(s"Update $updateId for batch $batchId failed") } @@ -328,8 +303,10 @@ class ServiceBackend( r } - def stop(): Unit = + def stop(): Unit = { executor.shutdownNow() + batchClient.close() + } override def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)] = ctx.time { @@ -438,17 +415,23 @@ object ServiceBackendAPI { val inputURL = argv(5) val outputURL = argv(6) - val fs = FS.buildRoutes(Some(s"$scratchDir/secrets/gsa-key/key.json"), None, sys.env) + val fs = RouterFS.buildRoutes( + CloudStorageFSConfig.fromFlagsAndEnv( + Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), + HailFeatureFlags.fromEnv(), + ) + ) val deployConfig = DeployConfig.fromConfigFile( s"$scratchDir/secrets/deploy-config/deploy-config.json" ) DeployConfig.set(deployConfig) sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) - val batchClient = new BatchClient(s"$scratchDir/secrets/gsa-key/key.json") + val batchClient = BatchClient(deployConfig, Path.of(scratchDir, "secrets/gsa-key/key.json")) log.info("BatchClient allocated.") - val batchConfig = BatchConfig.fromConfigFile(s"$scratchDir/batch-config/batch-config.json") + val batchConfig = + BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")) val batchId = batchConfig.map(_.batchId) val jobGroupId = batchConfig.map(_.jobGroupId) log.info("BatchConfig parsed.") @@ -518,8 +501,6 @@ private class HailSocketAPIOutputStream( } } -case class CloudfuseConfig(bucket: String, mount_path: String, read_only: Boolean) - case class SequenceConfig(fasta: String, index: String) case class ServiceBackendRPCPayload( diff --git a/hail/src/main/scala/is/hail/backend/service/Worker.scala b/hail/src/main/scala/is/hail/backend/service/Worker.scala index 95067e2d1fb..3dcfa2a63b4 100644 --- a/hail/src/main/scala/is/hail/backend/service/Worker.scala +++ b/hail/src/main/scala/is/hail/backend/service/Worker.scala @@ -1,6 +1,6 @@ package is.hail.backend.service -import is.hail.{HAIL_REVISION, HailContext} +import is.hail.{HAIL_REVISION, HailContext, HailFeatureFlags} import is.hail.asm4s._ import is.hail.backend.HailTaskContext import is.hail.io.fs._ @@ -11,14 +11,14 @@ import scala.collection.mutable import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.util.control.NonFatal - import java.io._ import java.nio.charset._ import java.util import java.util.{concurrent => javaConcurrent} - import org.apache.log4j.Logger +import java.nio.file.Path + class ServiceTaskContext(val partitionId: Int) extends HailTaskContext { override def stageId(): Int = 0 @@ -125,7 +125,12 @@ object Worker { timer.start(s"Job $i/$n") timer.start("readInputs") - val fs = FS.buildRoutes(Some(s"$scratchDir/secrets/gsa-key/key.json"), None, sys.env) + val fs = RouterFS.buildRoutes( + CloudStorageFSConfig.fromFlagsAndEnv( + Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), + HailFeatureFlags.fromEnv(), + ) + ) def open(x: String): SeekableDataInputStream = fs.openNoCompression(x) diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala 
b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index dfa71427bce..b6ded3487f3 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -332,7 +332,7 @@ class SparkBackend( val bmCache: SparkBlockMatrixCache = SparkBlockMatrixCache() - private[this] val flags = HailFeatureFlags.fromMap(sys.env) + private[this] val flags = HailFeatureFlags.fromEnv() def getFlag(name: String): String = flags.get(name) diff --git a/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala index 61444d1cdb7..a4d99c63fe2 100644 --- a/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala @@ -1,31 +1,25 @@ package is.hail.io.fs import is.hail.io.fs.FSUtil.dropTrailingSlash +import is.hail.services.oauth2.AzureCloudCredentials import is.hail.services.retryTransientErrors import is.hail.shadedazure.com.azure.core.credential.AzureSasCredential import is.hail.shadedazure.com.azure.core.util.HttpClientOptions -import is.hail.shadedazure.com.azure.identity.{ - ClientSecretCredentialBuilder, DefaultAzureCredentialBuilder, -} -import is.hail.shadedazure.com.azure.storage.blob.{ - BlobClient, BlobContainerClient, BlobServiceClient, BlobServiceClientBuilder, -} import is.hail.shadedazure.com.azure.storage.blob.models.{ BlobItem, BlobRange, BlobStorageException, ListBlobsOptions, } import is.hail.shadedazure.com.azure.storage.blob.specialized.BlockBlobClient -import is.hail.utils._ - -import scala.collection.JavaConverters._ -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import is.hail.shadedazure.com.azure.storage.blob.{ + BlobClient, BlobContainerClient, BlobServiceClient, BlobServiceClientBuilder, +} import java.io.{ByteArrayOutputStream, FileNotFoundException, OutputStream} import java.nio.file.Paths import java.time.Duration - -import org.json4s.Formats -import org.json4s.jackson.JsonMethods +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import java.nio.file.Path class AzureStorageFSURL( val account: String, @@ -54,17 +48,13 @@ class AzureStorageFSURL( prefix + pathPart } - override def toString(): String = { + override def toString: String = { val sasTokenPart = sasToken.getOrElse("") this.base + sasTokenPart } } object AzureStorageFS { - object EnvVars { - val AzureApplicationCredentials = "AZURE_APPLICATION_CREDENTIALS" - } - private val AZURE_HTTPS_URI_REGEX = "^https:\\/\\/([a-z0-9_\\-\\.]+)\\.blob\\.core\\.windows\\.net\\/([a-z0-9_\\-\\.]+)(\\/.*)?".r @@ -120,48 +110,30 @@ object AzureStorageFileListEntry { new BlobStorageFileListEntry(url.toString, null, 0, true) } -class AzureStorageFS(val credentialsJSON: Option[String] = None) extends FS { +case class AzureStorageFSConfig(credentials_file: Option[Path]) + +class AzureStorageFS(val credential: AzureCloudCredentials) extends FS { type URL = AzureStorageFSURL private[this] lazy val clients = mutable.Map[(String, String, Option[String]), BlobServiceClient]() - private lazy val credential = credentialsJSON match { - case None => - new DefaultAzureCredentialBuilder().build() - case Some(keyData) => - implicit val formats: Formats = defaultJSONFormats - val kvs = JsonMethods.parse(keyData) - val appId = (kvs \ "appId").extract[String] - val password = (kvs \ "password").extract[String] - val tenant = (kvs \ "tenant").extract[String] - - new 
ClientSecretCredentialBuilder() - .clientId(appId) - .clientSecret(password) - .tenantId(tenant) - .build() - } - def getServiceClient(url: URL): BlobServiceClient = { val k = (url.account, url.container, url.sasToken) - - clients.get(k) match { - case Some(client) => client - case None => + clients.getOrElseUpdate( + k, { val clientBuilder = url.sasToken match { case Some(sasToken) => new BlobServiceClientBuilder().credential(new AzureSasCredential(sasToken)) - case None => new BlobServiceClientBuilder().credential(credential) + case None => new BlobServiceClientBuilder().credential(credential.value) } - val blobServiceClient = clientBuilder + clientBuilder .endpoint(s"https://${url.account}.blob.core.windows.net") .clientOptions(httpClientOptions) .buildClient() - clients += (k -> blobServiceClient) - blobServiceClient - } + }, + ) } def setPublicAccessServiceClient(url: AzureStorageFSURL): Unit = { diff --git a/hail/src/main/scala/is/hail/io/fs/FS.scala b/hail/src/main/scala/is/hail/io/fs/FS.scala index cd0489ef9e9..441bde2cdfd 100644 --- a/hail/src/main/scala/is/hail/io/fs/FS.scala +++ b/hail/src/main/scala/is/hail/io/fs/FS.scala @@ -1,27 +1,22 @@ package is.hail.io.fs -import is.hail.{HailContext, HailFeatureFlags} + +import is.hail.HailContext import is.hail.backend.BroadcastValue import is.hail.io.compress.{BGzipInputStream, BGzipOutputStream} -import is.hail.io.fs.AzureStorageFS.EnvVars.AzureApplicationCredentials import is.hail.io.fs.FSUtil.{containsWildcard, dropTrailingSlash} -import is.hail.io.fs.GoogleStorageFS.EnvVars.GoogleApplicationCredentials import is.hail.services._ import is.hail.utils._ import scala.collection.mutable import scala.io.Source - import java.io._ import java.nio.ByteBuffer -import java.nio.charset._ import java.nio.file.FileSystems import java.util.zip.GZIPOutputStream - import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream import org.apache.commons.io.IOUtils import org.apache.hadoop -import org.apache.log4j.Logger class WrappedSeekableDataInputStream(is: SeekableInputStream) extends DataInputStream(is) with Seekable { @@ -258,56 +253,9 @@ abstract class FSPositionedOutputStream(val capacity: Int) extends OutputStream def getPosition: Long = pos } -object FS { - def buildRoutes( - credentialsPath: Option[String], - flags: Option[HailFeatureFlags], - env: Map[String, String], - ): FS = - retryTransientErrors { - - def readString(path: String): String = - using(new FileInputStream(path))(is => IOUtils.toString(is, Charset.defaultCharset())) - - def gcs = new GoogleStorageFS( - credentialsPath.orElse(sys.env.get(GoogleApplicationCredentials)).map(readString), - flags.flatMap(RequesterPaysConfig.fromFlags), - ) - - def az = env.get("HAIL_TERRA") match { - case Some(_) => new TerraAzureStorageFS() - case None => new AzureStorageFS( - credentialsPath.orElse(sys.env.get(AzureApplicationCredentials)).map(readString) - ) - } - - val cloudSpecificFSs = env.get("HAIL_CLOUD") match { - case Some("gcp") => FastSeq(gcs) - case Some("azure") => FastSeq(az) - case Some(cloud) => - throw new IllegalArgumentException(s"Unknown cloud provider: '$cloud'.'") - case None => - if (credentialsPath.isEmpty) FastSeq(gcs, az) - else fatal( - "Don't know to which cloud credentials belong because 'HAIL_CLOUD' was not set." 
- ) - } - - new RouterFS( - cloudSpecificFSs :+ new HadoopFS( - new SerializableHadoopConfiguration(new hadoop.conf.Configuration()) - ) - ) - } - - private val log = Logger.getLogger(getClass.getName()) -} - -trait FS extends Serializable { +trait FS extends Serializable with Logging { type URL <: FSURL - import FS.log - def parseUrl(filename: String): URL def validUrl(filename: String): Boolean diff --git a/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala index a2aa439ebb7..9ee4c1cad1e 100644 --- a/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala @@ -1,25 +1,21 @@ package is.hail.io.fs +import com.google.api.client.googleapis.json.GoogleJsonResponseException +import com.google.cloud.http.HttpTransportOptions +import com.google.cloud.storage.Storage.{BlobGetOption, BlobListOption, BlobSourceOption, BlobWriteOption} +import com.google.cloud.storage.{Option => _, _} +import com.google.cloud.{ReadChannel, WriteChannel} import is.hail.HailFeatureFlags import is.hail.io.fs.FSUtil.dropTrailingSlash +import is.hail.io.fs.GoogleStorageFS.RequesterPaysFailure +import is.hail.services.oauth2.GoogleCloudCredentials import is.hail.services.{isTransientError, retryTransientErrors} import is.hail.utils._ -import scala.jdk.CollectionConverters._ - -import java.io.{ByteArrayInputStream, FileNotFoundException, IOException} +import java.io.{FileNotFoundException, IOException} import java.nio.ByteBuffer -import java.nio.file.Paths - -import com.google.api.client.googleapis.json.GoogleJsonResponseException -import com.google.auth.oauth2.ServiceAccountCredentials -import com.google.cloud.{ReadChannel, WriteChannel} -import com.google.cloud.http.HttpTransportOptions -import com.google.cloud.storage.{Blob, BlobId, BlobInfo, Storage, StorageException, StorageOptions} -import com.google.cloud.storage.Storage.{ - BlobGetOption, BlobListOption, BlobSourceOption, BlobWriteOption, -} -import org.apache.log4j.Logger +import java.nio.file.{Path, Paths} +import scala.jdk.CollectionConverters._ case class GoogleStorageFSURL(bucket: String, path: String) extends FSURL { def addPathComponent(c: String): GoogleStorageFSURL = @@ -41,11 +37,7 @@ case class GoogleStorageFSURL(bucket: String, path: String) extends FSURL { } object GoogleStorageFS { - object EnvVars { - val GoogleApplicationCredentials = "GOOGLE_APPLICATION_CREDENTIALS" - } - private val log = Logger.getLogger(getClass.getName()) private[this] val GCS_URI_REGEX = "^gs:\\/\\/([a-z0-9_\\-\\.]+)(\\/.*)?".r def parseUrl(filename: String): GoogleStorageFSURL = { @@ -65,6 +57,32 @@ object GoogleStorageFS { ) } } + + object RequesterPaysFailure { + def unapply(t: Throwable): Option[Throwable] = + Some(t).filter { + case e: IOException => + Option(e.getCause).exists { + case RequesterPaysFailure(_) => true + case _ => false + } + + case exc: StorageException => + Option(exc.getMessage).exists { message => + message == "userProjectMissing" || + (exc.getCode == 400 && message.contains("requester pays")) + } + + case exc: GoogleJsonResponseException => + Option(exc.getMessage).exists { message => + message == "userProjectMissing" || + (exc.getStatusCode == 400 && message.contains("requester pays")) + } + + case _ => + false + } + } } object GoogleStorageFileListEntry { @@ -86,6 +104,8 @@ object GoogleStorageFileListEntry { new BlobStorageFileListEntry(url.toString, null, 0, true) } +case class RequesterPaysConfig(project: String, buckets: 
Option[Set[String]]) + object RequesterPaysConfig { object Flags { val RequesterPaysProject = "gcs_requester_pays_project" @@ -107,17 +127,17 @@ object RequesterPaysConfig { } } -case class RequesterPaysConfig(project: String, buckets: Option[Set[String]] = None) - extends Serializable +case class GoogleStorageFSConfig( + credentials_file: Option[Path], + requester_pays_config: Option[RequesterPaysConfig], +) class GoogleStorageFS( - private[this] val serviceAccountKey: Option[String] = None, - private[this] var requesterPaysConfig: Option[RequesterPaysConfig] = None, + private[this] val credentials: GoogleCloudCredentials, + private[this] var requesterPaysConfig: Option[RequesterPaysConfig], ) extends FS { type URL = GoogleStorageFSURL - import GoogleStorageFS.log - override def parseUrl(filename: String): URL = GoogleStorageFS.parseUrl(filename) override def validUrl(filename: String): Boolean = @@ -140,32 +160,6 @@ class GoogleStorageFS( Seq() } - object RequesterPaysFailure { - def unapply(t: Throwable): Option[Throwable] = - Some(t).filter { - case e: IOException => - Option(e.getCause).exists { - case RequesterPaysFailure(_) => true - case _ => false - } - - case exc: StorageException => - Option(exc.getMessage).exists { message => - message == "userProjectMissing" || - (exc.getCode == 400 && message.contains("requester pays")) - } - - case exc: GoogleJsonResponseException => - Option(exc.getMessage).exists { message => - message == "userProjectMissing" || - (exc.getStatusCode == 400 && message.contains("requester pays")) - } - - case _ => - false - } - } - private[this] def handleRequesterPays[T, U]( makeRequest: Seq[U] => T, makeUserProjectOption: String => U, @@ -185,23 +179,12 @@ class GoogleStorageFS( .setConnectTimeout(5000) .setReadTimeout(5000) .build() - serviceAccountKey match { - case None => - log.info("Initializing google storage client from latent credentials") - StorageOptions.newBuilder() - .setTransportOptions(transportOptions) - .build() - .getService - case Some(keyData) => - log.info("Initializing google storage client from service account key") - StorageOptions.newBuilder() - .setCredentials( - ServiceAccountCredentials.fromStream(new ByteArrayInputStream(keyData.getBytes)) - ) - .setTransportOptions(transportOptions) - .build() - .getService - } + + StorageOptions.newBuilder() + .setTransportOptions(transportOptions) + .setCredentials(credentials.value) + .build() + .getService } def openNoCompression(url: URL): SeekableDataInputStream = retryTransientErrors { diff --git a/hail/src/main/scala/is/hail/io/fs/RouterFS.scala b/hail/src/main/scala/is/hail/io/fs/RouterFS.scala index 7d8e9df3c52..b61bad0d74b 100644 --- a/hail/src/main/scala/is/hail/io/fs/RouterFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/RouterFS.scala @@ -1,5 +1,13 @@ package is.hail.io.fs +import is.hail.HailFeatureFlags +import is.hail.services.oauth2.{AzureCloudCredentials, GoogleCloudCredentials} +import is.hail.utils.{FastSeq, SerializableHadoopConfiguration} +import org.apache.hadoop.conf.Configuration + +import java.io.Serializable +import java.nio.file.Path + object RouterFSURL { def apply(fs: FS)(_url: fs.URL): RouterFSURL = RouterFSURL(_url, fs) } @@ -15,6 +23,47 @@ case class RouterFSURL private (_url: FSURL, val fs: FS) extends FSURL { override def toString(): String = url.toString } +case class CloudStorageFSConfig( + azure: Option[AzureStorageFSConfig] = None, + google: Option[GoogleStorageFSConfig] = None, +) extends Serializable + +object CloudStorageFSConfig { + def 
fromFlagsAndEnv( + credentialsFile: Option[Path], + flags: HailFeatureFlags, + env: Map[String, String] = sys.env, + ): CloudStorageFSConfig = { + env.get("HAIL_CLOUD") match { + case Some("azure") => + CloudStorageFSConfig(azure = Some(AzureStorageFSConfig(credentialsFile))) + case Some("gcp") | None => + val rpConf = RequesterPaysConfig.fromFlags(flags) + CloudStorageFSConfig(google = Some(GoogleStorageFSConfig(credentialsFile, rpConf))) + case _ => + CloudStorageFSConfig() + } + } +} + +object RouterFS { + + def buildRoutes(cloudConfig: CloudStorageFSConfig, env: Map[String, String] = sys.env): FS = + new RouterFS( + IndexedSeq.concat( + cloudConfig.google.map { case GoogleStorageFSConfig(path, maybeRequesterPaysConfig) => + new GoogleStorageFS(GoogleCloudCredentials(path), maybeRequesterPaysConfig) + }, + cloudConfig.azure.map { case AzureStorageFSConfig(path) => + val cred = AzureCloudCredentials(path) + if (env.contains("HAIL_TERRA")) new TerraAzureStorageFS(cred) + else new AzureStorageFS(cred) + }, + FastSeq(new HadoopFS(new SerializableHadoopConfiguration(new Configuration()))), + ) + ) +} + class RouterFS(fss: IndexedSeq[FS]) extends FS { type URL = RouterFSURL diff --git a/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala index 4078bf30603..98018a8c389 100644 --- a/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala @@ -1,29 +1,23 @@ package is.hail.io.fs -import is.hail.shadedazure.com.azure.core.credential.TokenRequestContext -import is.hail.shadedazure.com.azure.identity.{ - DefaultAzureCredential, DefaultAzureCredentialBuilder, -} +import is.hail.services.oauth2.AzureCloudCredentials import is.hail.shadedazure.com.azure.storage.blob.BlobServiceClient import is.hail.utils._ - -import scala.collection.mutable - import org.apache.http.client.methods.HttpPost import org.apache.http.client.utils.URIBuilder import org.apache.http.impl.client.HttpClients import org.apache.http.util.EntityUtils -import org.apache.log4j.Logger -import org.json4s.{DefaultFormats, Formats} import org.json4s.jackson.JsonMethods +import org.json4s.{DefaultFormats, Formats} + +import scala.collection.mutable object TerraAzureStorageFS { - private val log = Logger.getLogger(getClass.getName) private val TEN_MINUTES_IN_MS = 10 * 60 * 1000 } -class TerraAzureStorageFS extends AzureStorageFS() { - import TerraAzureStorageFS.{log, TEN_MINUTES_IN_MS} +class TerraAzureStorageFS(credential: AzureCloudCredentials) extends AzureStorageFS(credential) { + import TerraAzureStorageFS.TEN_MINUTES_IN_MS private[this] val httpClient = HttpClients.custom().build() private[this] val sasTokenCache = mutable.Map[String, (URL, Long)]() @@ -33,8 +27,6 @@ class TerraAzureStorageFS extends AzureStorageFS() { private[this] val containerResourceId = sys.env("WORKSPACE_STORAGE_CONTAINER_ID") private[this] val storageContainerUrl = parseUrl(sys.env("WORKSPACE_STORAGE_CONTAINER_URL")) - private[this] val credential: DefaultAzureCredential = new DefaultAzureCredentialBuilder().build() - override def getServiceClient(url: URL): BlobServiceClient = if (blobInWorkspaceStorageContainer(url)) { super.getServiceClient(getTerraSasToken(url)) @@ -59,13 +51,10 @@ class TerraAzureStorageFS extends AzureStorageFS() { private def createTerraSasToken(): (URL, Long) = { implicit val formats: Formats = DefaultFormats - val context = new TokenRequestContext() - 
context.addScopes("https://management.azure.com/.default") - val token = credential.getToken(context).block().getToken() - val url = s"$workspaceManagerUrl/api/workspaces/v1/$workspaceId/resources/controlled/azure/storageContainer/$containerResourceId/getSasToken" val req = new HttpPost(url) + val token = credential.accessToken(FastSeq("https://management.azure.com/.default")) req.addHeader("Authorization", s"Bearer $token") val tenHoursInSeconds = 10 * 3600 diff --git a/hail/src/main/scala/is/hail/services/BatchClient.scala b/hail/src/main/scala/is/hail/services/BatchClient.scala new file mode 100644 index 00000000000..780ee1c4819 --- /dev/null +++ b/hail/src/main/scala/is/hail/services/BatchClient.scala @@ -0,0 +1,311 @@ +package is.hail.services + +import is.hail.expr.ir.ByteArrayBuilder +import is.hail.services.requests.{BatchServiceRequester, Requester} +import is.hail.utils._ +import org.apache.http.entity.ByteArrayEntity +import org.apache.http.entity.ContentType.APPLICATION_JSON +import org.json4s.JsonAST.{JArray, JBool} +import org.json4s.jackson.JsonMethods +import org.json4s.{CustomSerializer, DefaultFormats, Extraction, Formats, JInt, JObject, JString} + +import java.nio.charset.StandardCharsets +import java.nio.file.Path +import scala.util.Random + +case class BatchRequest( + billing_project: String, + n_jobs: Int, + token: String, + attributes: Map[String, String] = Map.empty, +) + +case class JobGroupRequest( + job_group_id: Int, + absolute_parent_id: Int, + attributes: Map[String, String] = Map.empty, +) + +case class JobRequest( + job_id: Int, + always_run: Boolean, + in_update_job_group_id: Int, + in_update_parent_ids: Array[Int], + process: JobProcess, + resources: Option[JobResources] = None, + regions: Option[Array[String]] = None, + cloudfuse: Option[Array[CloudfuseConfig]] = None, + attributes: Map[String, String] = Map.empty, +) + +sealed trait JobProcess + +case class BashJob( + image: String, + command: Array[String], +) extends JobProcess + +case class JvmJob( + command: Array[String], + jar_url: String, + profile: Boolean, +) extends JobProcess + +case class JobResources( + preemptible: Boolean, + cpu: Option[String], + memory: Option[String], + storage: Option[String], +) + +case class CloudfuseConfig( + bucket: String, + mount_path: String, + read_only: Boolean, +) + +case class JobGroupResponse( + batch_id: Int, + job_group_id: Int, + state: JobGroupState, + complete: Boolean, + n_jobs: Int, + n_completed: Int, + n_succeeded: Int, + n_failed: Int, + n_cancelled: Int, +) + +sealed trait JobGroupState extends Product with Serializable + +object JobGroupStates { + case object Failure extends JobGroupState + case object Cancelled extends JobGroupState + case object Success extends JobGroupState + case object Running extends JobGroupState +} + +object BatchClient { + def apply(deployConfig: DeployConfig, credentialsFile: Path): BatchClient = + new BatchClient(BatchServiceRequester(deployConfig, credentialsFile)) +} + +case class BatchClient private (req: Requester) extends Logging with AutoCloseable { + + implicit private[this] val fmts: Formats = + DefaultFormats + + JobProcessRequestSerializer + + JobGroupStateDeserializer + + JobGroupResponseDeserializer + + def newBatch(createRequest: BatchRequest): Int = { + val response = req.post("/api/v1alpha/batches/create", Extraction.decompose(createRequest)) + val batchId = (response \ "id").extract[Int] + log.info(s"run: created batch $batchId") + batchId + } + + def newJobGroup( + batchId: Int, + token: String, + 
jobGroup: JobGroupRequest,
+    jobs: IndexedSeq[JobRequest],
+  ): (Int, Int) = {
+
+    val updateJson = JObject(
+      "n_jobs" -> JInt(jobs.length),
+      "n_job_groups" -> JInt(1),
+      "token" -> JString(token),
+    )
+
+    val jobGroupSpec = getJsonBytes(jobGroup)
+    val jobBunches = createBunches(jobs)
+    val updateIDAndJobGroupId =
+      if (jobBunches.length == 1 && jobBunches(0).length + jobGroupSpec.length < 1024 * 1024) {
+        val b = new ByteArrayBuilder()
+        b ++= "{\"job_groups\":".getBytes(StandardCharsets.UTF_8)
+        addBunchBytes(b, Array(jobGroupSpec))
+        b ++= ",\"bunch\":".getBytes(StandardCharsets.UTF_8)
+        addBunchBytes(b, jobBunches(0))
+        b ++= ",\"update\":".getBytes(StandardCharsets.UTF_8)
+        b ++= JsonMethods.compact(updateJson).getBytes(StandardCharsets.UTF_8)
+        b += '}'
+        val data = b.result()
+        val resp = req.post(
+          s"/api/v1alpha/batches/$batchId/update-fast",
+          new ByteArrayEntity(data, APPLICATION_JSON),
+        )
+        b.clear()
+        ((resp \ "update_id").extract[Int], (resp \ "start_job_group_id").extract[Int])
+      } else {
+        val resp = req.post(s"/api/v1alpha/batches/$batchId/updates/create", updateJson)
+        val updateID = (resp \ "update_id").extract[Int]
+        val startJobGroupId = (resp \ "start_job_group_id").extract[Int]
+
+        val b = new ByteArrayBuilder()
+        b ++= "[".getBytes(StandardCharsets.UTF_8)
+        b ++= jobGroupSpec
+        b ++= "]".getBytes(StandardCharsets.UTF_8)
+        req.post(
+          s"/api/v1alpha/batches/$batchId/updates/$updateID/job-groups/create",
+          new ByteArrayEntity(b.result(), APPLICATION_JSON),
+        )
+
+        b.clear()
+        var i = 0
+        while (i < jobBunches.length) {
+          addBunchBytes(b, jobBunches(i))
+          val data = b.result()
+          req.post(
+            s"/api/v1alpha/batches/$batchId/updates/$updateID/jobs/create",
+            new ByteArrayEntity(data, APPLICATION_JSON),
+          )
+          b.clear()
+          i += 1
+        }
+
+        req.patch(s"/api/v1alpha/batches/$batchId/updates/$updateID/commit")
+        (updateID, startJobGroupId)
+      }
+
+    log.info(s"run: created update $updateIDAndJobGroupId for batch $batchId")
+    updateIDAndJobGroupId
+  }
+
+  def run(batchRequest: BatchRequest, jobGroup: JobGroupRequest, jobs: IndexedSeq[JobRequest])
+    : JobGroupResponse = {
+    val batchID = newBatch(batchRequest)
+    val (_, jobGroupId) = newJobGroup(batchID, batchRequest.token, jobGroup, jobs)
+    waitForJobGroup(batchID, jobGroupId)
+  }
+
+  def waitForJobGroup(batchID: Int, jobGroupId: Int): JobGroupResponse = {
+
+    Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms
+
+    val start = System.nanoTime()
+
+    while (true) {
+      val jobGroup = req
+        .get(s"/api/v1alpha/batches/$batchID/job-groups/$jobGroupId")
+        .extract[JobGroupResponse]
+
+      if (jobGroup.complete)
+        return jobGroup
+
+      // wait 10% of duration so far
+      // at least, 50ms
+      // at most, 5s
+      val now = System.nanoTime()
+      val elapsed = now - start
+      val d = math.max(
+        math.min(
+          (0.1 * (0.8 + Random.nextFloat() * 0.4) * (elapsed / 1000.0 / 1000)).toInt,
+          5000,
+        ),
+        50,
+      )
+      Thread.sleep(d)
+    }
+
+    throw new AssertionError("unreachable")
+  }
+
+  private def createBunches(jobs: IndexedSeq[JobRequest]): BoxedArrayBuilder[Array[Array[Byte]]] = {
+    val bunches = new BoxedArrayBuilder[Array[Array[Byte]]]()
+    val bunchb = new BoxedArrayBuilder[Array[Byte]]()
+
+    var i = 0
+    var size = 0
+    while (i < jobs.length) {
+      val jobBytes = getJsonBytes(jobs(i))
+      if (size + jobBytes.length > 1024 * 1024) {
+        bunches += bunchb.result()
+        bunchb.clear()
+        size = 0
+      }
+      bunchb += jobBytes
+      size += jobBytes.length
+      i += 1
+    }
+    assert(bunchb.size > 0)
+
+    bunches += bunchb.result()
+    bunchb.clear()
+    bunches
+  }
+
+  private def getJsonBytes(obj: Any): Array[Byte] =
+    JsonMethods.compact(Extraction.decompose(obj)).getBytes(StandardCharsets.UTF_8)
+
+  private def addBunchBytes(b: ByteArrayBuilder, bunch: Array[Array[Byte]]): Unit = {
+    var j = 0
+    b += '['
+    while (j < bunch.length) {
+      if (j > 0)
+        b += ','
+      b ++= bunch(j)
+      j += 1
+    }
+    b += ']'
+  }
+
+  override def close(): Unit =
+    req.close()
+
+  private[this] object JobProcessRequestSerializer
+      extends CustomSerializer[JobProcess](_ =>
+        (
+          PartialFunction.empty,
+          {
+            case BashJob(image, command) =>
+              JObject(
+                "type" -> JString("docker"),
+                "image" -> JString(image),
+                "command" -> JArray(command.map(JString).toList),
+              )
+            case JvmJob(command, url, profile) =>
+              JObject(
+                "type" -> JString("jvm"),
+                "command" -> JArray(command.map(JString).toList),
+                "jar_spec" -> JObject("type" -> JString("jar_url"), "value" -> JString(url)),
+                "profile" -> JBool(profile),
+              )
+          },
+        )
+      )
+
+  private[this] object JobGroupStateDeserializer
+      extends CustomSerializer[JobGroupState](_ =>
+        (
+          {
+            case JString("failure") => JobGroupStates.Failure
+            case JString("cancelled") => JobGroupStates.Cancelled
+            case JString("success") => JobGroupStates.Success
+            case JString("running") => JobGroupStates.Running
+          },
+          PartialFunction.empty,
+        )
+      )
+
+  private[this] object JobGroupResponseDeserializer
+      extends CustomSerializer[JobGroupResponse](implicit fmts =>
+        (
+          {
+            case o: JObject =>
+              JobGroupResponse(
+                batch_id = (o \ "batch_id").extract[Int],
+                job_group_id = (o \ "job_group_id").extract[Int],
+                state = (o \ "state").extract[JobGroupState],
+                complete = (o \ "complete").extract[Boolean],
+                n_jobs = (o \ "n_jobs").extract[Int],
+                n_completed = (o \ "n_completed").extract[Int],
+                n_succeeded = (o \ "n_succeeded").extract[Int],
+                n_failed = (o \ "n_failed").extract[Int],
+                n_cancelled = (o \ "n_cancelled").extract[Int],
+              )
+          },
+          PartialFunction.empty,
+        )
+      )
+}
diff --git a/hail/src/main/scala/is/hail/services/BatchConfig.scala b/hail/src/main/scala/is/hail/services/BatchConfig.scala
index 9d2e0256ff5..3cd8e3b0c62 100644
--- a/hail/src/main/scala/is/hail/services/BatchConfig.scala
+++ b/hail/src/main/scala/is/hail/services/BatchConfig.scala
@@ -1,19 +1,15 @@
 package is.hail.services
 
 import is.hail.utils._
-
-import java.io.{File, FileInputStream}
-
 import org.json4s._
 import org.json4s.jackson.JsonMethods
 
+import java.nio.file.{Files, Path}
+
 object BatchConfig {
-  def fromConfigFile(file: String): Option[BatchConfig] =
-    if (new File(file).exists()) {
-      using(new FileInputStream(file))(in => Some(fromConfig(JsonMethods.parse(in))))
-    } else {
-      None
-    }
+  def fromConfigFile(file: Path): Option[BatchConfig] =
+    if (!file.toFile.exists()) None
+    else using(Files.newInputStream(file))(in => Some(fromConfig(JsonMethods.parse(in))))
 
   def fromConfig(config: JValue): BatchConfig = {
     implicit val formats: Formats = DefaultFormats
@@ -21,4 +17,4 @@ object BatchConfig {
   }
 }
 
-class BatchConfig(val batchId: Long, val jobGroupId: Long)
+case class BatchConfig(batchId: Int, jobGroupId: Int)
diff --git a/hail/src/main/scala/is/hail/services/Requester.scala b/hail/src/main/scala/is/hail/services/Requester.scala
deleted file mode 100644
index fcfbd808ba8..00000000000
--- a/hail/src/main/scala/is/hail/services/Requester.scala
+++ /dev/null
@@ -1,168 +0,0 @@
-package is.hail.services
-
-import is.hail.shadedazure.com.azure.core.credential.TokenRequestContext
-import is.hail.shadedazure.com.azure.identity.{
-  ClientSecretCredential, ClientSecretCredentialBuilder,
-}
-import is.hail.utils._ - -import scala.collection.JavaConverters._ - -import java.io.{FileInputStream, InputStream} - -import com.google.auth.oauth2.ServiceAccountCredentials -import org.apache.commons.io.IOUtils -import org.apache.http.{HttpEntity, HttpEntityEnclosingRequest} -import org.apache.http.client.config.RequestConfig -import org.apache.http.client.methods.HttpUriRequest -import org.apache.http.impl.client.{CloseableHttpClient, HttpClients} -import org.apache.http.util.EntityUtils -import org.apache.log4j.{LogManager, Logger} -import org.json4s.{Formats, JValue} -import org.json4s.jackson.JsonMethods - -abstract class CloudCredentials { - def accessToken(): String -} - -class GoogleCloudCredentials(gsaKeyPath: String) extends CloudCredentials { - private[this] val credentials = using(new FileInputStream(gsaKeyPath)) { is => - ServiceAccountCredentials - .fromStream(is) - .createScoped("openid", "email", "profile") - } - - override def accessToken(): String = { - credentials.refreshIfExpired() - credentials.getAccessToken.getTokenValue - } -} - -class AzureCloudCredentials(credentialsPath: String) extends CloudCredentials { - private[this] val credentials: ClientSecretCredential = - using(new FileInputStream(credentialsPath)) { is => - implicit val formats: Formats = defaultJSONFormats - val kvs = JsonMethods.parse(is) - val appId = (kvs \ "appId").extract[String] - val password = (kvs \ "password").extract[String] - val tenant = (kvs \ "tenant").extract[String] - - new ClientSecretCredentialBuilder() - .clientId(appId) - .clientSecret(password) - .tenantId(tenant) - .build() - } - - override def accessToken(): String = { - val context = new TokenRequestContext() - context.setScopes(Array(System.getenv("HAIL_AZURE_OAUTH_SCOPE")).toList.asJava) - credentials.getToken(context).block.getToken - } -} - -class ClientResponseException( - val status: Int, - message: String, - cause: Throwable, -) extends Exception(message, cause) { - def this(statusCode: Int) = this(statusCode, null, null) - - def this(statusCode: Int, message: String) = this(statusCode, message, null) -} - -object Requester { - private val log: Logger = LogManager.getLogger("Requester") - private[this] val TIMEOUT_MS = 5 * 1000 - - val httpClient: CloseableHttpClient = { - log.info("creating HttpClient") - val requestConfig = RequestConfig.custom() - .setConnectTimeout(TIMEOUT_MS) - .setConnectionRequestTimeout(TIMEOUT_MS) - .setSocketTimeout(TIMEOUT_MS) - .build() - try { - HttpClients.custom() - .setSSLContext(tls.getSSLContext) - .setMaxConnPerRoute(20) - .setMaxConnTotal(100) - .setDefaultRequestConfig(requestConfig) - .build() - } catch { - case _: NoSSLConfigFound => - log.info("creating HttpClient with no SSL Context") - HttpClients.custom() - .setMaxConnPerRoute(20) - .setMaxConnTotal(100) - .setDefaultRequestConfig(requestConfig) - .build() - } - } - - def fromCredentialsFile(credentialsPath: String) = { - val credentials = sys.env.get("HAIL_CLOUD") match { - case Some("gcp") => new GoogleCloudCredentials(credentialsPath) - case Some("azure") => new AzureCloudCredentials(credentialsPath) - case Some(cloud) => - throw new IllegalArgumentException(s"Bad cloud: $cloud") - case None => - throw new IllegalArgumentException(s"HAIL_CLOUD must be set.") - } - new Requester(credentials) - } -} - -class Requester( - val credentials: CloudCredentials -) { - import Requester._ - - def requestWithHandler[T >: Null](req: HttpUriRequest, body: HttpEntity, f: InputStream => T) - : T = { - log.info(s"request ${req.getMethod} 
${req.getURI}") - - if (body != null) - req.asInstanceOf[HttpEntityEnclosingRequest].setEntity(body) - - val token = credentials.accessToken() - req.addHeader("Authorization", s"Bearer $token") - - retryTransientErrors { - using(httpClient.execute(req)) { resp => - val statusCode = resp.getStatusLine.getStatusCode - log.info(s"request ${req.getMethod} ${req.getURI} response $statusCode") - if (statusCode < 200 || statusCode >= 300) { - val entity = resp.getEntity - val message = - if (entity != null) - EntityUtils.toString(entity) - else - null - throw new ClientResponseException(statusCode, message) - } - val entity: HttpEntity = resp.getEntity - if (entity != null) { - using(entity.getContent)(f) - } else - null - } - } - } - - def requestAsByteStream(req: HttpUriRequest, body: HttpEntity = null): Array[Byte] = - requestWithHandler(req, body, IOUtils.toByteArray) - - def request(req: HttpUriRequest, body: HttpEntity = null): JValue = - requestWithHandler( - req, - body, - { content => - val s = IOUtils.toByteArray(content) - if (s.isEmpty) - null - else - JsonMethods.parse(new String(s)) - }, - ) -} diff --git a/hail/src/main/scala/is/hail/services/batch_client/BatchClient.scala b/hail/src/main/scala/is/hail/services/batch_client/BatchClient.scala deleted file mode 100644 index e6ea6ea1aa2..00000000000 --- a/hail/src/main/scala/is/hail/services/batch_client/BatchClient.scala +++ /dev/null @@ -1,263 +0,0 @@ -package is.hail.services.batch_client - -import is.hail.expr.ir.ByteArrayBuilder -import is.hail.services._ -import is.hail.utils._ - -import scala.util.Random - -import java.nio.charset.StandardCharsets - -import org.apache.http.HttpEntity -import org.apache.http.client.methods.{HttpDelete, HttpGet, HttpPatch, HttpPost} -import org.apache.http.entity.{ByteArrayEntity, ContentType, StringEntity} -import org.apache.log4j.{LogManager, Logger} -import org.json4s.{DefaultFormats, Formats, JInt, JObject, JString, JValue} -import org.json4s.jackson.JsonMethods - -class NoBodyException(message: String, cause: Throwable) extends Exception(message, cause) { - def this() = this(null, null) - - def this(message: String) = this(message, null) -} - -object BatchClient { - lazy val log: Logger = LogManager.getLogger("BatchClient") -} - -class BatchClient( - deployConfig: DeployConfig, - requester: Requester, -) { - - def this(credentialsPath: String) = - this(DeployConfig.get, Requester.fromCredentialsFile(credentialsPath)) - - import BatchClient._ - import requester.request - - private[this] val baseUrl = deployConfig.baseUrl("batch") - - def get(path: String): JValue = - request(new HttpGet(s"$baseUrl$path")) - - def post(path: String, body: HttpEntity): JValue = - request(new HttpPost(s"$baseUrl$path"), body = body) - - def post(path: String, json: JValue = null): JValue = - post( - path, - if (json != null) - new StringEntity( - JsonMethods.compact(json), - ContentType.create("application/json"), - ) - else - null, - ) - - def patch(path: String): JValue = - request(new HttpPatch(s"$baseUrl$path")) - - def delete(path: String, token: String): JValue = - request(new HttpDelete(s"$baseUrl$path")) - - def update(batchID: Long, token: String, jobGroup: JObject, jobs: IndexedSeq[JObject]) - : (Long, Long) = { - implicit val formats: Formats = DefaultFormats - - val updateJson = - JObject("n_jobs" -> JInt(jobs.length), "n_job_groups" -> JInt(1), "token" -> JString(token)) - val jobGroupSpec = specBytes(jobGroup) - val jobBunches = createBunches(jobs) - val updateIDAndJobGroupId = - if 
(jobBunches.length == 1 && jobBunches(0).length + jobGroupSpec.length < 1024 * 1024) { - val b = new ByteArrayBuilder() - b ++= "{\"job_groups\":".getBytes(StandardCharsets.UTF_8) - addBunchBytes(b, Array(jobGroupSpec)) - b ++= ",\"bunch\":".getBytes(StandardCharsets.UTF_8) - addBunchBytes(b, jobBunches(0)) - b ++= ",\"update\":".getBytes(StandardCharsets.UTF_8) - b ++= JsonMethods.compact(updateJson).getBytes(StandardCharsets.UTF_8) - b += '}' - val data = b.result() - val resp = retryTransientErrors { - post( - s"/api/v1alpha/batches/$batchID/update-fast", - new ByteArrayEntity(data, ContentType.create("application/json")), - ) - } - b.clear() - ((resp \ "update_id").extract[Long], (resp \ "start_job_group_id").extract[Long]) - } else { - val resp = retryTransientErrors { - post(s"/api/v1alpha/batches/$batchID/updates/create", json = updateJson) - } - val updateID = (resp \ "update_id").extract[Long] - val startJobGroupId = (resp \ "start_job_group_id").extract[Long] - - val b = new ByteArrayBuilder() - b ++= "[".getBytes(StandardCharsets.UTF_8) - b ++= jobGroupSpec - b ++= "]".getBytes(StandardCharsets.UTF_8) - retryTransientErrors { - post( - s"/api/v1alpha/batches/$batchID/updates/$updateID/job-groups/create", - new ByteArrayEntity(b.result(), ContentType.create("application/json")), - ) - } - - b.clear() - var i = 0 - while (i < jobBunches.length) { - addBunchBytes(b, jobBunches(i)) - val data = b.result() - retryTransientErrors { - post( - s"/api/v1alpha/batches/$batchID/updates/$updateID/jobs/create", - new ByteArrayEntity( - data, - ContentType.create("application/json"), - ), - ) - } - b.clear() - i += 1 - } - - retryTransientErrors { - patch(s"/api/v1alpha/batches/$batchID/updates/$updateID/commit") - } - (updateID, startJobGroupId) - } - - log.info(s"run: created update $updateIDAndJobGroupId for batch $batchID") - updateIDAndJobGroupId - } - - def create(batchJson: JObject, jobs: IndexedSeq[JObject]): Long = { - implicit val formats: Formats = DefaultFormats - - val bunches = createBunches(jobs) - val batchID = if (bunches.length == 1) { - val bunch = bunches(0) - val b = new ByteArrayBuilder() - b ++= "{\"bunch\":".getBytes(StandardCharsets.UTF_8) - addBunchBytes(b, bunch) - b ++= ",\"batch\":".getBytes(StandardCharsets.UTF_8) - b ++= JsonMethods.compact(batchJson).getBytes(StandardCharsets.UTF_8) - b += '}' - val data = b.result() - val resp = retryTransientErrors { - post( - "/api/v1alpha/batches/create-fast", - new ByteArrayEntity(data, ContentType.create("application/json")), - ) - } - b.clear() - (resp \ "id").extract[Long] - } else { - val resp = retryTransientErrors(post("/api/v1alpha/batches/create", json = batchJson)) - val batchID = (resp \ "id").extract[Long] - - val b = new ByteArrayBuilder() - - var i = 0 - while (i < bunches.length) { - addBunchBytes(b, bunches(i)) - val data = b.result() - retryTransientErrors { - post( - s"/api/v1alpha/batches/$batchID/jobs/create", - new ByteArrayEntity( - data, - ContentType.create("application/json"), - ), - ) - } - b.clear() - i += 1 - } - - retryTransientErrors(patch(s"/api/v1alpha/batches/$batchID/close")) - batchID - } - log.info(s"run: created batch $batchID") - batchID - } - - def run(batchJson: JObject, jobs: IndexedSeq[JObject]): JValue = { - val batchID = create(batchJson, jobs) - waitForJobGroup(batchID, 0L) - } - - def waitForJobGroup(batchID: Long, jobGroupId: Long): JValue = { - implicit val formats: Formats = DefaultFormats - - Thread.sleep(600) // it is not possible for the batch to be finished in less than 
600ms
-
-    val start = System.nanoTime()
-
-    while (true) {
-      val jobGroup =
-        retryTransientErrors(get(s"/api/v1alpha/batches/$batchID/job-groups/$jobGroupId"))
-      if ((jobGroup \ "complete").extract[Boolean])
-        return jobGroup
-
-      // wait 10% of duration so far
-      // at least, 50ms
-      // at most, 5s
-      val now = System.nanoTime()
-      val elapsed = now - start
-      val d = math.max(
-        math.min(
-          (0.1 * (0.8 + Random.nextFloat() * 0.4) * (elapsed / 1000.0 / 1000)).toInt,
-          5000,
-        ),
-        50,
-      )
-      Thread.sleep(d)
-    }
-
-    throw new AssertionError("unreachable")
-  }
-
-  private def createBunches(jobs: IndexedSeq[JObject]): BoxedArrayBuilder[Array[Array[Byte]]] = {
-    val bunches = new BoxedArrayBuilder[Array[Array[Byte]]]()
-    val bunchb = new BoxedArrayBuilder[Array[Byte]]()
-
-    var i = 0
-    var size = 0
-    while (i < jobs.length) {
-      val jobBytes = specBytes(jobs(i))
-      if (size + jobBytes.length > 1024 * 1024) {
-        bunches += bunchb.result()
-        bunchb.clear()
-        size = 0
-      }
-      bunchb += jobBytes
-      size += jobBytes.length
-      i += 1
-    }
-    assert(bunchb.size > 0)
-
-    bunches += bunchb.result()
-    bunchb.clear()
-    bunches
-  }
-
-  private def specBytes(obj: JObject): Array[Byte] =
-    JsonMethods.compact(obj).getBytes(StandardCharsets.UTF_8)
-
-  private def addBunchBytes(b: ByteArrayBuilder, bunch: Array[Array[Byte]]): Unit = {
-    var j = 0
-    b += '['
-    while (j < bunch.length) {
-      if (j > 0)
-        b += ','
-      b ++= bunch(j)
-      j += 1
-    }
-    b += ']'
-  }
-}
diff --git a/hail/src/main/scala/is/hail/services/oauth2.scala b/hail/src/main/scala/is/hail/services/oauth2.scala
new file mode 100644
index 00000000000..986cb7be658
--- /dev/null
+++ b/hail/src/main/scala/is/hail/services/oauth2.scala
@@ -0,0 +1,98 @@
+package is.hail.services
+
+import com.google.auth.oauth2.{GoogleCredentials, ServiceAccountCredentials}
+import is.hail.services.oauth2.AzureCloudCredentials.EnvVars.AzureApplicationCredentials
+import is.hail.services.oauth2.GoogleCloudCredentials.EnvVars.GoogleApplicationCredentials
+import is.hail.shadedazure.com.azure.core.credential.{TokenCredential, TokenRequestContext}
+import is.hail.shadedazure.com.azure.identity.{
+  ClientSecretCredentialBuilder, DefaultAzureCredentialBuilder,
+}
+import is.hail.utils.{defaultJSONFormats, using}
+import org.json4s.Formats
+import org.json4s.jackson.JsonMethods
+
+import java.io.Serializable
+import java.nio.file.{Files, Path}
+import scala.collection.JavaConverters._
+
+object oauth2 {
+
+  sealed trait CloudCredentials extends Product with Serializable {
+    def accessToken(scopes: IndexedSeq[String]): String
+  }
+
+  def CloudCredentials(credentialsPath: Path, env: Map[String, String] = sys.env)
+    : CloudCredentials =
+    env.get("HAIL_CLOUD") match {
+      case Some("gcp") => GoogleCloudCredentials(Some(credentialsPath))
+      case Some("azure") => AzureCloudCredentials(Some(credentialsPath))
+      case Some(cloud) => throw new IllegalArgumentException(s"Unknown cloud: '$cloud'")
+      case None => throw new IllegalArgumentException(s"HAIL_CLOUD must be set.")
+    }
+
+  def CloudScopes(env: Map[String, String] = sys.env): Array[String] =
+    env.get("HAIL_CLOUD") match {
+      case Some("gcp") => Array("openid", "email", "profile")
+      case Some("azure") => env.get("HAIL_AZURE_OAUTH_SCOPE").toArray
+      case Some(cloud) => throw new IllegalArgumentException(s"Unknown cloud: '$cloud'.")
+      case None => throw new IllegalArgumentException(s"HAIL_CLOUD must be set.")
+    }
+
+  case class GoogleCloudCredentials(value: GoogleCredentials) extends CloudCredentials {
+    override def accessToken(scopes: 
IndexedSeq[String]): String = {
+      // scope the credentials before requesting a token; a freshly scoped copy has no cached token
+      val scoped = value.createScoped(scopes.asJava)
+      scoped.refreshAccessToken().getTokenValue
+    }
+  }
+
+  object GoogleCloudCredentials {
+    object EnvVars {
+      val GoogleApplicationCredentials = "GOOGLE_APPLICATION_CREDENTIALS"
+    }
+
+    def apply(keyPath: Option[Path], env: Map[String, String] = sys.env): GoogleCloudCredentials =
+      GoogleCloudCredentials(
+        keyPath.orElse(env.get(GoogleApplicationCredentials).map(Path.of(_))) match {
+          case Some(path) => using(Files.newInputStream(path))(ServiceAccountCredentials.fromStream)
+          case None => GoogleCredentials.getApplicationDefault
+        }
+      )
+  }
+
+  sealed trait AzureCloudCredentials extends CloudCredentials {
+    def value: TokenCredential
+
+    override def accessToken(scopes: IndexedSeq[String]): String =
+      value.getTokenSync(new TokenRequestContext().setScopes(scopes.asJava)).getToken
+  }
+
+  object AzureCloudCredentials {
+    object EnvVars {
+      val AzureApplicationCredentials = "AZURE_APPLICATION_CREDENTIALS"
+    }
+
+    def apply(keyPath: Option[Path], env: Map[String, String] = sys.env): AzureCloudCredentials =
+      keyPath.orElse(env.get(AzureApplicationCredentials).map(Path.of(_))) match {
+        case Some(path) => AzureClientSecretCredentials(path)
+        case None => AzureDefaultCredentials
+      }
+  }
+
+  private case object AzureDefaultCredentials extends AzureCloudCredentials {
+    @transient override lazy val value: TokenCredential =
+      new DefaultAzureCredentialBuilder().build()
+  }
+
+  private case class AzureClientSecretCredentials(path: Path) extends AzureCloudCredentials {
+    @transient override lazy val value: TokenCredential =
+      using(Files.newInputStream(path)) { is =>
+        implicit val fmts: Formats = defaultJSONFormats
+        val kvs = JsonMethods.parse(is)
+        new ClientSecretCredentialBuilder()
+          .clientId((kvs \ "appId").extract[String])
+          .clientSecret((kvs \ "password").extract[String])
+          .tenantId((kvs \ "tenant").extract[String])
+          .build()
+      }
+  }
+}
diff --git a/hail/src/main/scala/is/hail/services/package.scala b/hail/src/main/scala/is/hail/services/package.scala
index 161448ef102..c0369af14b6 100644
--- a/hail/src/main/scala/is/hail/services/package.scala
+++ b/hail/src/main/scala/is/hail/services/package.scala
@@ -4,13 +4,13 @@ import is.hail.shadedazure.com.azure.storage.common.implementation.Constants
 import is.hail.utils._
 
 import scala.util.Random
-
 import java.io._
 import java.net._
-
 import com.google.api.client.googleapis.json.GoogleJsonResponseException
 import com.google.api.client.http.HttpResponseException
 import com.google.cloud.storage.StorageException
+import is.hail.services.requests.ClientResponseException
+
 import javax.net.ssl.SSLException
 import org.apache.http.{ConnectionClosedException, NoHttpResponseException}
 import org.apache.http.conn.HttpHostConnectException
diff --git a/hail/src/main/scala/is/hail/services/requests.scala b/hail/src/main/scala/is/hail/services/requests.scala
new file mode 100644
index 00000000000..05a13ed057c
--- /dev/null
+++ b/hail/src/main/scala/is/hail/services/requests.scala
@@ -0,0 +1,101 @@
+package is.hail.services
+
+import is.hail.services.oauth2.{CloudCredentials, CloudScopes}
+import is.hail.utils.{log, _}
+import org.apache.http.client.config.RequestConfig
+import org.apache.http.client.methods.{HttpGet, HttpPatch, HttpPost, HttpUriRequest}
+import org.apache.http.entity.ContentType.APPLICATION_JSON
+import org.apache.http.entity.StringEntity
+import org.apache.http.impl.client.{CloseableHttpClient, HttpClients}
+import org.apache.http.util.EntityUtils
+import 
org.apache.http.{HttpEntity, HttpEntityEnclosingRequest} +import org.json4s.JValue +import org.json4s.JsonAST.JNothing +import org.json4s.jackson.JsonMethods + +import java.net.URL +import java.nio.file.Path + +object requests { + + class ClientResponseException(val status: Int, message: String) extends Exception(message) + + trait Requester extends AutoCloseable { + def get(route: String): JValue + def post(route: String, body: JValue): JValue + def post(route: String, body: HttpEntity): JValue + def patch(route: String): JValue + } + + private[this] val TIMEOUT_MS = 5 * 1000 + + def BatchServiceRequester(conf: DeployConfig, keyFile: Path, env: Map[String, String] = sys.env) + : Requester = + Requester( + new URL(conf.baseUrl("batch")), + CloudCredentials(keyFile, env), + CloudScopes(env), + ) + + def Requester(baseUrl: URL, cred: CloudCredentials, scopes: IndexedSeq[String]): Requester = { + + val httpClient: CloseableHttpClient = { + log.info("creating HttpClient") + val requestConfig = RequestConfig.custom() + .setConnectTimeout(TIMEOUT_MS) + .setConnectionRequestTimeout(TIMEOUT_MS) + .setSocketTimeout(TIMEOUT_MS) + .build() + try { + HttpClients.custom() + .setSSLContext(tls.getSSLContext) + .setMaxConnPerRoute(20) + .setMaxConnTotal(100) + .setDefaultRequestConfig(requestConfig) + .build() + } catch { + case _: NoSSLConfigFound => + log.info("creating HttpClient with no SSL Context") + HttpClients.custom() + .setMaxConnPerRoute(20) + .setMaxConnTotal(100) + .setDefaultRequestConfig(requestConfig) + .build() + } + } + + def request(req: HttpUriRequest, body: Option[HttpEntity] = None): JValue = { + log.info(s"request ${req.getMethod} ${req.getURI}") + req.addHeader("Authorization", s"Bearer ${cred.accessToken(scopes)}") + body.foreach(entity => req.asInstanceOf[HttpEntityEnclosingRequest].setEntity(entity)) + retryTransientErrors { + using(httpClient.execute(req)) { resp => + val statusCode = resp.getStatusLine.getStatusCode + log.info(s"request ${req.getMethod} ${req.getURI} response $statusCode") + val message = Option(resp.getEntity).map(EntityUtils.toString) + if (statusCode < 200 || statusCode >= 300) { + throw new ClientResponseException(statusCode, message.orNull) + } + message.map(JsonMethods.parse(_)).getOrElse(JNothing) + } + } + } + + new Requester with Logging { + override def get(route: String): JValue = + request(new HttpGet(s"$baseUrl$route")) + + override def post(route: String, body: JValue): JValue = + post(route, new StringEntity(JsonMethods.compact(body), APPLICATION_JSON)) + + override def post(route: String, body: HttpEntity): JValue = + request(new HttpPost(s"$baseUrl$route"), Some(body)) + + override def patch(route: String): JValue = + request(new HttpPatch(s"$baseUrl$route")) + + override def close(): Unit = + httpClient.close() + } + } +} diff --git a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala index 14a9a1689d9..faddc4a6ce4 100644 --- a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala +++ b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala @@ -2,12 +2,9 @@ package is.hail.backend import is.hail.asm4s.HailClassLoader import is.hail.backend.service.{ServiceBackend, ServiceBackendRPCPayload} -import is.hail.services.batch_client.BatchClient +import is.hail.services.JobGroupStates.Success +import is.hail.services._ import is.hail.utils.tokenUrlSafe - -import scala.reflect.io.{Directory, Path} - -import org.json4s.{JArray, JBool, JInt, JObject, JString} 
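      // Aside: a minimal usage sketch of the requests.Requester API defined above;
      // not part of this patch. The key path, batch id, and route are hypothetical,
      // and `using` (from is.hail.utils) closes the Requester's pooled HttpClient.
      //
      //   using(requests.BatchServiceRequester(DeployConfig.get(), Path.of("/keys/key.json"))) { req =>
      //     val batch = req.get("/api/v1alpha/batches/123") // any GET route works the same way
      //     println(batch)
      //   }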
import org.mockito.ArgumentMatchersSugar.any import org.mockito.IdiomaticMockito import org.mockito.MockitoSugar.when @@ -15,6 +12,9 @@ import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.scalatestplus.testng.TestNGSuite import org.testng.annotations.Test +import scala.reflect.io.{Directory, Path} +import scala.util.Random + class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito { @Test def testCreateJobPayload(): Unit = @@ -23,8 +23,7 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito { val backend = ServiceBackend( - jarLocation = - classOf[ServiceBackend].getProtectionDomain.getCodeSource.getLocation.getPath, + jarLocation = "us-docker.pkg.dev/hail-vdc/hail/hailgenetics/hail@sha256:fake", name = "name", theHailClassLoader = new HailClassLoader(getClass.getClassLoader), batchClient, @@ -42,30 +41,39 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito { // - the number of jobs matches the number of partitions, and // - each job is created in the specified region, and // - each job's resource configuration matches the rpc config - when(batchClient.create(any[JObject], any[IndexedSeq[JObject]])) thenAnswer { - (batch: JObject, jobs: IndexedSeq[JObject]) => - batch \ "billing_project" shouldBe JString(rpcPayload.billing_project) - batch \ "n_jobs" shouldBe JInt(contexts.length) + val batchId = Random.nextInt() + + when(batchClient.newBatch(any[BatchRequest])) thenAnswer { + (batchRequest: BatchRequest) => + batchRequest.billing_project shouldEqual rpcPayload.billing_project + batchId + } + + when(batchClient.newJobGroup(any[Int], any[String], any[JobGroupRequest], any[IndexedSeq[JobRequest]])) thenAnswer { + (_: Int, _: String, jobGroup: JobGroupRequest, jobs: IndexedSeq[JobRequest]) => + jobGroup.job_group_id shouldBe 1 + jobGroup.absolute_parent_id shouldBe 0 jobs.length shouldEqual contexts.length jobs.foreach { payload => - payload \ "regions" shouldBe JArray(rpcPayload.regions.map(JString).toList) - - payload \ "resources" shouldBe JObject( - "preemptible" -> JBool(true), - "cpu" -> JString(rpcPayload.worker_cores), - "memory" -> JString(rpcPayload.worker_memory), - "storage" -> JString(rpcPayload.storage), + payload.regions shouldBe rpcPayload.regions + payload.resources shouldBe Some( + JobResources( + preemptible = true, + cpu = Some(rpcPayload.worker_cores), + memory = Some(rpcPayload.worker_memory), + storage = Some(rpcPayload.storage) + ) ) } - 37L + (batchId, 37) } // the service backend expects that each job write its output to a well-known // location when it finishes. 
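      // Illustration of that contract (values hypothetical): for a stage run with
      // token "abc123" against remoteTmpDir "gs://bucket/tmp", partition i must
      // leave its serialized output at
      //
      //   gs://bucket/tmp/parallelizeAndComputeWithIndex/abc123/result.$i
      //
      // which is why the mock below writes a dummy "result.$i" file per context
      // before reporting the job group as complete.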
- when(batchClient.waitForJobGroup(any[Long], any[Long])) thenAnswer { - (batchId: Long, jobGroupId: Long) => + when(batchClient.waitForJobGroup(any[Int], any[Int])) thenAnswer { + (batchId: Int, jobGroupId: Int) => batchId shouldEqual 37L jobGroupId shouldEqual 1L @@ -76,81 +84,17 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito { resultsDir.createDirectory() for (i <- contexts.indices) (resultsDir / f"result.$i").toFile.writeAll("11") - JObject("state" -> JString("success")) - } - - val (failure, _) = - backend.parallelizeAndComputeWithIndex( - backend.serviceBackendContext, - backend.fs, - contexts, - "stage1", - )((bytes, _, _, _) => bytes) - - failure.foreach(throw _) - - batchClient.create(any[JObject], any[IndexedSeq[JObject]]) wasCalled once - } - - @Test def testUpdateJobPayload(): Unit = - withMockDriverContext { config => - val batchClient = mock[BatchClient] - - val backend = - ServiceBackend( - jarLocation = - classOf[ServiceBackend].getProtectionDomain.getCodeSource.getLocation.getPath, - name = "name", - theHailClassLoader = new HailClassLoader(getClass.getClassLoader), - batchClient, - batchId = Some(23L), - jobGroupId = None, - scratchDir = config.remote_tmpdir, - rpcConfig = config, - sys.env + ("HAIL_CLOUD" -> "gcp"), - ) - - val contexts = Array.tabulate(1)(_.toString.getBytes) - - // verify that the service backend - // - updates the batch with the correct billing project, and - // - the number of jobs matches the number of partitions, and - // - each job is created in the specified region, and - // - each job's resource configuration matches the rpc config - when( - batchClient.update(any[Long], any[String], any[JObject], any[IndexedSeq[JObject]]) - ) thenAnswer { - (batchId: Long, _: String, _: JObject, jobs: IndexedSeq[JObject]) => - batchId shouldEqual 23L - - jobs.length shouldEqual contexts.length - jobs.foreach { payload => - payload \ "regions" shouldBe JArray(config.regions.map(JString).toList) - - payload \ "resources" shouldBe JObject( - "preemptible" -> JBool(true), - "cpu" -> JString(config.worker_cores), - "memory" -> JString(config.worker_memory), - "storage" -> JString(config.storage), - ) - } - - (2L, 3L) - } - - when(batchClient.waitForJobGroup(any[Long], any[Long])) thenAnswer { - (batchId: Long, jobGroupId: Long) => - batchId shouldEqual 23L - jobGroupId shouldEqual 3L - - val resultsDir = - Path(backend.serviceBackendContext.remoteTmpDir) / - "parallelizeAndComputeWithIndex" / - tokenUrlSafe - - resultsDir.createDirectory() - for (i <- contexts.indices) (resultsDir / f"result.$i").toFile.writeAll("11") - JObject("state" -> JString("success")) + JobGroupResponse( + batch_id = batchId, + job_group_id = jobGroupId, + state = Success, + complete = true, + n_jobs = contexts.length, + n_completed = contexts.length, + n_succeeded = contexts.length, + n_failed = 0, + n_cancelled = 0 + ) } val (failure, _) = @@ -163,13 +107,8 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito { failure.foreach(throw _) - batchClient.create(any[JObject], any[IndexedSeq[JObject]]) wasNever called - batchClient.update( - any[Long], - any[String], - any[JObject], - any[IndexedSeq[JObject]], - ) wasCalled once + batchClient.newBatch(any) wasCalled once + batchClient.newJobGroup(any, any, any, any) wasCalled once } def withMockDriverContext(test: ServiceBackendRPCPayload => Any): Any = diff --git a/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala b/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala index 
d6f3b8cbc4f..3dd7be5b944 100644
--- a/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala
+++ b/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala
@@ -1,7 +1,8 @@
 package is.hail.io.fs
 
-import java.io.FileInputStream
+import is.hail.services.oauth2.AzureCloudCredentials
 
+import java.io.FileInputStream
 import org.apache.commons.io.IOUtils
 import org.testng.SkipException
 import org.testng.annotations.{BeforeClass, Test}
@@ -17,16 +18,7 @@ class AzureStorageFSSuite extends FSSuite {
     }
   }
 
-  lazy val fs = {
-    val aac = System.getenv("AZURE_APPLICATION_CREDENTIALS")
-    if (aac == null) {
-      new AzureStorageFS()
-    } else {
-      new AzureStorageFS(
-        Some(new String(IOUtils.toByteArray(new FileInputStream(aac))))
-      )
-    }
-  }
+  lazy val fs = new AzureStorageFS(AzureCloudCredentials(None))
 
   @Test def testMakeQualified(): Unit = {
     val qualifiedFileName = "https://account.blob.core.windows.net/container/path"
diff --git a/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala b/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala
index 4f6e654b87c..f2c911a3fd2 100644
--- a/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala
+++ b/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala
@@ -1,8 +1,6 @@
 package is.hail.io.fs
 
-import java.io.FileInputStream
-
-import org.apache.commons.io.IOUtils
+import is.hail.services.oauth2.GoogleCloudCredentials
 import org.scalatestplus.testng.TestNGSuite
 import org.testng.SkipException
 import org.testng.annotations.{BeforeClass, Test}
@@ -18,16 +16,7 @@ class GoogleStorageFSSuite extends TestNGSuite with FSSuite {
     }
   }
 
-  lazy val fs = {
-    val gac = System.getenv("GOOGLE_APPLICATION_CREDENTIALS")
-    if (gac == null) {
-      new GoogleStorageFS()
-    } else {
-      new GoogleStorageFS(
-        Some(new String(IOUtils.toByteArray(new FileInputStream(gac))))
-      )
-    }
-  }
+  lazy val fs = new GoogleStorageFS(GoogleCloudCredentials(None), None)
 
   @Test def testMakeQualified(): Unit = {
     val qualifiedFileName = "gs://bucket/path"
diff --git a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala
new file mode 100644
index 00000000000..ddbbec0573e
--- /dev/null
+++ b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala
@@ -0,0 +1,40 @@
+package is.hail.services
+
+import is.hail.utils._
+import org.scalatestplus.testng.TestNGSuite
+import org.testng.annotations.Test
+
+import java.nio.file.Path
+
+class BatchClientSuite extends TestNGSuite {
+  @Test def testBasic(): Unit = {
+    using(BatchClient(DeployConfig.get(), Path.of("/test-gsa-key/key.json"))) { client =>
+
+      val jobGroup = client.run(
+        BatchRequest(
+          billing_project = "test",
+          n_jobs = 1,
+          token = tokenUrlSafe,
+        ),
+        JobGroupRequest(
+          job_group_id = 0,
+          absolute_parent_id = 0,
+        ),
+        FastSeq(
+          JobRequest(
+            job_id = 0,
+            always_run = false,
+            in_update_job_group_id = 0,
+            in_update_parent_ids = Array(),
+            process = BashJob(
+              image = "ubuntu:22.04",
+              command = Array("/bin/bash", "-c", "echo 'hello, world!'"),
+            ),
+          )
+        ),
+      )
+
+      assert(jobGroup.state == JobGroupStates.Success)
+    }
+  }
+}
diff --git a/hail/src/test/scala/is/hail/services/batch_client/BatchClientSuite.scala b/hail/src/test/scala/is/hail/services/batch_client/BatchClientSuite.scala
deleted file mode 100644
index 521d40046e4..00000000000
--- a/hail/src/test/scala/is/hail/services/batch_client/BatchClientSuite.scala
+++ /dev/null
@@ -1,40 +0,0 @@
-package is.hail.services.batch_client
-
-import is.hail.utils._
-
-import org.json4s.{DefaultFormats, Formats}
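// Aside: an illustrative sketch, not part of this patch, of the two process
// shapes a JobRequest can carry in the new typed API. BashJob (as in the new
// suite above) runs a command in a container image; JvmJob, which the
// ServiceBackend uses to launch workers, runs a JAR. The command, root path,
// and jar_url values here are hypothetical:
//
//   val bash = BashJob(
//     image = "ubuntu:22.04",
//     command = Array("/bin/bash", "-c", "echo 'hello, world!'"),
//   )
//   val jvm = JvmJob(
//     command = Array(Main.WORKER, "gs://bucket/tmp/stage1", "0", "1"),
//     jar_url = "gs://hail-jars/hail.jar",
//     profile = false,
//   )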
-import org.json4s.JsonAST._ -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test - -class BatchClientSuite extends TestNGSuite { - @Test def testBasic(): Unit = { - val client = new BatchClient("/test-gsa-key/key.json") - val token = tokenUrlSafe - val batch = client.run( - JObject( - "billing_project" -> JString("test"), - "n_jobs" -> JInt(1), - "token" -> JString(token), - ), - FastSeq( - JObject( - "always_run" -> JBool(false), - "job_id" -> JInt(0), - "parent_ids" -> JArray(List()), - "process" -> JObject( - "image" -> JString("ubuntu:22.04"), - "command" -> JArray(List( - JString("/bin/bash"), - JString("-c"), - JString("echo 'Hello, world!'"), - )), - "type" -> JString("docker"), - ), - ) - ), - ) - implicit val formats: Formats = DefaultFormats - assert((batch \ "state").extract[String] == "success") - } -} From fe19971f31b0588676f79e33c0fab7b56711a207 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Fri, 13 Sep 2024 13:58:45 -0400 Subject: [PATCH 02/20] fix gcs oauth2 scopes --- .../src/main/scala/is/hail/services/oauth2.scala | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/hail/src/main/scala/is/hail/services/oauth2.scala b/hail/src/main/scala/is/hail/services/oauth2.scala index 986cb7be658..657521600ce 100644 --- a/hail/src/main/scala/is/hail/services/oauth2.scala +++ b/hail/src/main/scala/is/hail/services/oauth2.scala @@ -32,10 +32,18 @@ object oauth2 { def CloudScopes(env: Map[String, String] = sys.env): Array[String] = env.get("HAIL_CLOUD") match { - case Some("gcp") => Array("openid", "email", "profile") - case Some("azure") => sys.env.get("HAIL_AZURE_OAUTH_SCOPE").toArray - case Some(cloud) => throw new IllegalArgumentException(s"Unknown cloud: '$cloud'.") - case None => throw new IllegalArgumentException(s"HAIL_CLOUD must be set.") + case Some("gcp") => + Array( + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/userinfo.email", + "openid", + ) + case Some("azure") => + sys.env.get("HAIL_AZURE_OAUTH_SCOPE").toArray + case Some(cloud) => + throw new IllegalArgumentException(s"Unknown cloud: '$cloud'.") + case None => + throw new IllegalArgumentException(s"HAIL_CLOUD must be set.") } case class GoogleCloudCredentials(value: GoogleCredentials) extends CloudCredentials { From b01378ab86e9ce39b680c92a26b82eebeb52520c Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Fri, 13 Sep 2024 16:46:40 -0400 Subject: [PATCH 03/20] make backend closable + fix mocked test --- hail/python/hail/backend/py4j_backend.py | 2 +- hail/python/hail/context.py | 2 +- hail/python/hailtop/config/user_config.py | 5 +- hail/src/main/scala/is/hail/HailContext.scala | 2 +- .../main/scala/is/hail/backend/Backend.scala | 4 +- .../scala/is/hail/backend/BackendServer.scala | 12 +- .../is/hail/backend/local/LocalBackend.scala | 2 +- .../hail/backend/service/ServiceBackend.scala | 10 +- .../is/hail/backend/service/Worker.scala | 22 +- .../is/hail/backend/spark/SparkBackend.scala | 8 +- .../scala/is/hail/services/BatchClient.scala | 5 +- .../is/hail/backend/ServiceBackendSuite.scala | 227 ++++++++++-------- .../is/hail/io/fs/AzureStorageFSSuite.scala | 2 - .../is/hail/io/fs/GoogleStorageFSSuite.scala | 1 + .../is/hail/services/BatchClientSuite.scala | 9 +- 15 files changed, 162 insertions(+), 151 deletions(-) diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py index 49a7d9b8e14..4e216e22aad 100644 --- a/hail/python/hail/backend/py4j_backend.py +++ 
b/hail/python/hail/backend/py4j_backend.py @@ -306,7 +306,7 @@ def _to_java_blockmatrix_ir(self, ir): return self._parse_blockmatrix_ir(self._render_ir(ir)) def stop(self): - self._backend_server.stop() + self._backend_server.close() self._jhc.stop() self._jhc = None self._registered_ir_function_names = set() diff --git a/hail/python/hail/context.py b/hail/python/hail/context.py index 5258f27fbc1..56897eec987 100644 --- a/hail/python/hail/context.py +++ b/hail/python/hail/context.py @@ -151,7 +151,7 @@ def default_reference(self, value): def stop(self): assert self._backend - self._backend.stop() + self._backend.close() self._backend = None Env._hc = None Env._dummy_table = None diff --git a/hail/python/hailtop/config/user_config.py b/hail/python/hailtop/config/user_config.py index 752031eb166..a6d003b4c12 100644 --- a/hail/python/hailtop/config/user_config.py +++ b/hail/python/hailtop/config/user_config.py @@ -144,6 +144,5 @@ def get_remote_tmpdir( raise ValueError( f'remote_tmpdir must be a storage uri path like gs://bucket/folder. Received: {remote_tmpdir}. Possible schemes include gs for GCP and https for Azure' ) - if remote_tmpdir[-1] != '/': - remote_tmpdir += '/' - return remote_tmpdir + + return remote_tmpdir[:-1] if remote_tmpdir[-1] == '/' else remote_tmpdir diff --git a/hail/src/main/scala/is/hail/HailContext.scala b/hail/src/main/scala/is/hail/HailContext.scala index 1d2fe403762..3de89ce13cd 100644 --- a/hail/src/main/scala/is/hail/HailContext.scala +++ b/hail/src/main/scala/is/hail/HailContext.scala @@ -136,7 +136,7 @@ object HailContext { def stop(): Unit = synchronized { IRFunctionRegistry.clearUserFunctions() - backend.stop() + backend.close() theContext = null } diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala index 2b32a471025..e0444e371e2 100644 --- a/hail/src/main/scala/is/hail/backend/Backend.scala +++ b/hail/src/main/scala/is/hail/backend/Backend.scala @@ -74,7 +74,7 @@ trait BackendContext { def executionCache: ExecutionCache } -abstract class Backend { +abstract class Backend extends Closeable { // From https://github.com/hail-is/hail/issues/14580 : // IR can get quite big, especially as it can contain an arbitrary // amount of encoded literals from the user's python session. 
This @@ -123,7 +123,7 @@ abstract class Backend { f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) - def stop(): Unit + def close(): Unit def asSpark(op: String): SparkBackend = fatal(s"${getClass.getSimpleName}: $op requires SparkBackend") diff --git a/hail/src/main/scala/is/hail/backend/BackendServer.scala b/hail/src/main/scala/is/hail/backend/BackendServer.scala index e23d4a3d1e3..0217fc1d8d7 100644 --- a/hail/src/main/scala/is/hail/backend/BackendServer.scala +++ b/hail/src/main/scala/is/hail/backend/BackendServer.scala @@ -4,16 +4,16 @@ import is.hail.expr.ir.{IRParser, IRParserEnvironment} import is.hail.utils._ import scala.util.control.NonFatal - import java.net.InetSocketAddress import java.nio.charset.StandardCharsets import java.util.concurrent._ - import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer} import org.json4s._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods.compact +import java.io.Closeable + case class IRTypePayload(ir: String) case class LoadReferencesFromDatasetPayload(path: String) @@ -31,11 +31,7 @@ case class ParseVCFMetadataPayload(path: String) case class ImportFamPayload(path: String, quant_pheno: Boolean, delimiter: String, missing: String) case class ExecutePayload(ir: String, stream_codec: String, timed: Boolean) -object BackendServer { - def apply(backend: Backend) = new BackendServer(backend) -} - -class BackendServer(backend: Backend) { +class BackendServer(backend: Backend) extends Closeable { // 0 => let the OS pick an available port private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10) private[this] val handler = new BackendHttpHandler(backend) @@ -77,7 +73,7 @@ class BackendServer(backend: Backend) { def start(): Unit = thread.start() - def stop(): Unit = + override def close(): Unit = httpServer.stop(10) } diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index f2c130ae639..2b43df11419 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -137,7 +137,7 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache def defaultParallelism: Int = 1 - def stop(): Unit = LocalBackend.stop() + def close(): Unit = LocalBackend.stop() private[this] def _jvmLowerAndExecute( ctx: ExecuteContext, diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 0ee85562352..668a3ec2286 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -122,7 +122,7 @@ class ServiceBackend( val tmpdir: String, val fs: FS, val serviceBackendContext: ServiceBackendContext, - val scratchDir: String = sys.env.get("HAIL_WORKER_SCRATCH_DIR").getOrElse(""), + val scratchDir: String, ) extends Backend with BackendWithNoCodeCache { import ServiceBackend.log @@ -165,7 +165,7 @@ class ServiceBackend( val backendContext = _backendContext.asInstanceOf[ServiceBackendContext] val n = collection.length val token = tokenUrlSafe - val root = s"${backendContext.remoteTmpDir}parallelizeAndComputeWithIndex/$token" + val root = s"${backendContext.remoteTmpDir}/parallelizeAndComputeWithIndex/$token" log.info(s"parallelizeAndComputeWithIndex: $token: nPartitions $n") 
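    // Worked example (bucket and token hypothetical): get_remote_tmpdir on the
    // Python side now returns "gs://bucket/tmp" with no trailing slash, so the
    // explicit "/" in `root` above yields
    //
    //   root = "gs://bucket/tmp/parallelizeAndComputeWithIndex/abc123"
    //
    // and each worker then writes its output to s"$root/result.$i".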
log.info(s"parallelizeAndComputeWithIndex: $token: writing f and contexts") @@ -303,7 +303,7 @@ class ServiceBackend( r } - def stop(): Unit = { + override def close(): Unit = { executor.shutdownNow() batchClient.close() } @@ -421,9 +421,7 @@ object ServiceBackendAPI { HailFeatureFlags.fromEnv(), ) ) - val deployConfig = DeployConfig.fromConfigFile( - s"$scratchDir/secrets/deploy-config/deploy-config.json" - ) + val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") DeployConfig.set(deployConfig) sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) diff --git a/hail/src/main/scala/is/hail/backend/service/Worker.scala b/hail/src/main/scala/is/hail/backend/service/Worker.scala index 3dcfa2a63b4..5a6ec9f9505 100644 --- a/hail/src/main/scala/is/hail/backend/service/Worker.scala +++ b/hail/src/main/scala/is/hail/backend/service/Worker.scala @@ -1,23 +1,22 @@ package is.hail.backend.service -import is.hail.{HAIL_REVISION, HailContext, HailFeatureFlags} import is.hail.asm4s._ import is.hail.backend.HailTaskContext import is.hail.io.fs._ import is.hail.services._ import is.hail.utils._ +import is.hail.{HAIL_REVISION, HailContext, HailFeatureFlags} +import org.apache.log4j.Logger -import scala.collection.mutable -import scala.concurrent.{Await, ExecutionContext, Future} -import scala.concurrent.duration.Duration -import scala.util.control.NonFatal import java.io._ import java.nio.charset._ +import java.nio.file.Path import java.util import java.util.{concurrent => javaConcurrent} -import org.apache.log4j.Logger - -import java.nio.file.Path +import scala.collection.mutable +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.util.control.NonFatal class ServiceTaskContext(val partitionId: Int) extends HailTaskContext { override def stageId(): Int = 0 @@ -113,9 +112,7 @@ object Worker { val n = argv(6).toInt val timer = new WorkerTimer() - val deployConfig = DeployConfig.fromConfigFile( - s"$scratchDir/secrets/deploy-config/deploy-config.json" - ) + val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") DeployConfig.set(deployConfig) sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) @@ -167,7 +164,6 @@ object Worker { timer.end("readInputs") timer.start("executeFunction") - if (HailContext.isInitialized) { HailContext.get.backend = new ServiceBackend( null, @@ -180,6 +176,7 @@ object Worker { null, null, null, + scratchDir, ) } else { HailContext( @@ -195,6 +192,7 @@ object Worker { null, null, null, + scratchDir, ) ) } diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index b6ded3487f3..2b4720cd10e 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -472,7 +472,10 @@ class SparkBackend( override def asSpark(op: String): SparkBackend = this - def stop(): Unit = SparkBackend.stop() + def close(): Unit = { + SparkBackend.stop() + longLifeTempFileManager.close() + } def startProgressBar(): Unit = ProgressBarBuilder.build(sc) @@ -761,9 +764,6 @@ class SparkBackend( RVDTableReader(RVD.unkeyed(rowPType, orderedCRDD), globalsLit, rt) } - def close(): Unit = - longLifeTempFileManager.close() - def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) : TableStage = { CanLowerEfficiently(ctx, inputIR) match { diff --git 
a/hail/src/main/scala/is/hail/services/BatchClient.scala b/hail/src/main/scala/is/hail/services/BatchClient.scala index 780ee1c4819..5d458900484 100644 --- a/hail/src/main/scala/is/hail/services/BatchClient.scala +++ b/hail/src/main/scala/is/hail/services/BatchClient.scala @@ -86,8 +86,9 @@ object JobGroupStates { } object BatchClient { - def apply(deployConfig: DeployConfig, credentialsFile: Path): BatchClient = - new BatchClient(BatchServiceRequester(deployConfig, credentialsFile)) + def apply(deployConfig: DeployConfig, credentialsFile: Path, env: Map[String, String] = sys.env) + : BatchClient = + new BatchClient(BatchServiceRequester(deployConfig, credentialsFile, env)) } case class BatchClient private (req: Requester) extends Logging with AutoCloseable { diff --git a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala index faddc4a6ce4..57b17a1c59c 100644 --- a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala +++ b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala @@ -1,123 +1,146 @@ package is.hail.backend +import is.hail.HailFeatureFlags import is.hail.asm4s.HailClassLoader -import is.hail.backend.service.{ServiceBackend, ServiceBackendRPCPayload} -import is.hail.services.JobGroupStates.Success +import is.hail.backend.service.{ServiceBackend, ServiceBackendContext, ServiceBackendRPCPayload} +import is.hail.io.fs.{CloudStorageFSConfig, RouterFS} import is.hail.services._ -import is.hail.utils.tokenUrlSafe +import is.hail.services.JobGroupStates.Success +import is.hail.utils.{tokenUrlSafe, using} + +import scala.reflect.io.{Directory, Path} +import scala.util.Random + +import java.io.Closeable + import org.mockito.ArgumentMatchersSugar.any import org.mockito.IdiomaticMockito import org.mockito.MockitoSugar.when +import org.scalatest.OptionValues import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.scalatestplus.testng.TestNGSuite import org.testng.annotations.Test -import scala.reflect.io.{Directory, Path} -import scala.util.Random - -class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito { +class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionValues { @Test def testCreateJobPayload(): Unit = - withMockDriverContext { rpcPayload => + withMockDriverContext { rpcConfig => val batchClient = mock[BatchClient] + using(ServiceBackend(batchClient, rpcConfig)) { backend => + val contexts = Array.tabulate(1)(_.toString.getBytes) + + // verify that the service backend + // - creates the batch with the correct billing project, and + // - the number of jobs matches the number of partitions, and + // - each job is created in the specified region, and + // - each job's resource configuration matches the rpc config + val batchId = Random.nextInt() + + when(batchClient.newBatch(any[BatchRequest])) thenAnswer { + (batchRequest: BatchRequest) => + batchRequest.billing_project shouldEqual rpcConfig.billing_project + batchRequest.n_jobs shouldBe 0 + batchRequest.attributes.get("name").value shouldBe backend.name + batchId + } - val backend = - ServiceBackend( - jarLocation = "us-docker.pkg.dev/hail-vdc/hail/hailgenetics/hail@sha256:fake", - name = "name", - theHailClassLoader = new HailClassLoader(getClass.getClassLoader), - batchClient, - batchId = None, - jobGroupId = None, - scratchDir = rpcPayload.remote_tmpdir, - rpcConfig = rpcPayload, - sys.env + ("HAIL_CLOUD" -> "gcp"), - ) - - val contexts = Array.tabulate(1)(_.toString.getBytes) 
- - // verify that the service backend - // - creates the batch with the correct billing project, and - // - the number of jobs matches the number of partitions, and - // - each job is created in the specified region, and - // - each job's resource configuration matches the rpc config - val batchId = Random.nextInt() - - when(batchClient.newBatch(any[BatchRequest])) thenAnswer { - (batchRequest: BatchRequest) => - batchRequest.billing_project shouldEqual rpcPayload.billing_project - batchId - } - - when(batchClient.newJobGroup(any[Int], any[String], any[JobGroupRequest], any[IndexedSeq[JobRequest]])) thenAnswer { - (_: Int, _: String, jobGroup: JobGroupRequest, jobs: IndexedSeq[JobRequest]) => - jobGroup.job_group_id shouldBe 1 - jobGroup.absolute_parent_id shouldBe 0 - - jobs.length shouldEqual contexts.length - jobs.foreach { payload => - payload.regions shouldBe rpcPayload.regions - payload.resources shouldBe Some( - JobResources( + when(batchClient.newJobGroup( + any[Int], + any[String], + any[JobGroupRequest], + any[IndexedSeq[JobRequest]], + )) thenAnswer { + (id: Int, _: String, jobGroup: JobGroupRequest, jobs: IndexedSeq[JobRequest]) => + id shouldBe batchId + jobGroup.job_group_id shouldBe 1 + jobGroup.absolute_parent_id shouldBe 0 + jobs.length shouldEqual contexts.length + jobs.foreach { payload => + payload.regions.value shouldBe rpcConfig.regions + payload.resources.value shouldBe JobResources( preemptible = true, - cpu = Some(rpcPayload.worker_cores), - memory = Some(rpcPayload.worker_memory), - storage = Some(rpcPayload.storage) + cpu = Some(rpcConfig.worker_cores), + memory = Some(rpcConfig.worker_memory), + storage = Some(rpcConfig.storage), ) - ) - } + } - (batchId, 37) - } + (37, 1) + } - // the service backend expects that each job write its output to a well-known - // location when it finishes. - when(batchClient.waitForJobGroup(any[Int], any[Int])) thenAnswer { - (batchId: Int, jobGroupId: Int) => - batchId shouldEqual 37L - jobGroupId shouldEqual 1L - - val resultsDir = - Path(backend.serviceBackendContext.remoteTmpDir) / - "parallelizeAndComputeWithIndex" / - tokenUrlSafe - - resultsDir.createDirectory() - for (i <- contexts.indices) (resultsDir / f"result.$i").toFile.writeAll("11") - JobGroupResponse( - batch_id = batchId, - job_group_id = jobGroupId, - state = Success, - complete = true, - n_jobs = contexts.length, - n_completed = contexts.length, - n_succeeded = contexts.length, - n_failed = 0, - n_cancelled = 0 - ) - } + // the service backend expects that each job write its output to a well-known + // location when it finishes. 
+ when(batchClient.waitForJobGroup(any[Int], any[Int])) thenAnswer { + (id: Int, jobGroupId: Int) => + id shouldEqual batchId + jobGroupId shouldEqual 1 + + val resultsDir = + Path(backend.serviceBackendContext.remoteTmpDir) / + "parallelizeAndComputeWithIndex" / + tokenUrlSafe + + resultsDir.createDirectory() + for (i <- contexts.indices) (resultsDir / f"result.$i").toFile.writeAll("11") + JobGroupResponse( + batch_id = id, + job_group_id = jobGroupId, + state = Success, + complete = true, + n_jobs = contexts.length, + n_completed = contexts.length, + n_succeeded = contexts.length, + n_failed = 0, + n_cancelled = 0, + ) + } - val (failure, _) = - backend.parallelizeAndComputeWithIndex( - backend.serviceBackendContext, - backend.fs, - contexts, - "stage1", - )((bytes, _, _, _) => bytes) + val (failure, _) = + backend.parallelizeAndComputeWithIndex( + backend.serviceBackendContext, + backend.fs, + contexts, + "stage1", + )((bytes, _, _, _) => bytes) - failure.foreach(throw _) + failure.foreach(throw _) - batchClient.newBatch(any) wasCalled once - batchClient.newJobGroup(any, any, any, any) wasCalled once + batchClient.newBatch(any) wasCalled once + batchClient.newJobGroup(any, any, any, any) wasCalled once + } } - def withMockDriverContext(test: ServiceBackendRPCPayload => Any): Any = - withNewLocalTmpFolder { tmp => - // The `ServiceBackend` assumes credentials are installed to a well known location - val gcsKeyDir = tmp / "secrets" / "gsa-key" - gcsKeyDir.createDirectory() - (gcsKeyDir / "key.json").toFile.writeAll("password1234") + def ServiceBackend(client: BatchClient, rpcConfig: ServiceBackendRPCPayload): ServiceBackend = { + val flags = HailFeatureFlags.fromEnv() + val fs = RouterFS.buildRoutes(CloudStorageFSConfig()) + new ServiceBackend( + jarLocation = "us-docker.pkg.dev/hail-vdc/hail/hailgenetics/hail@sha256:fake", + name = "name", + theHailClassLoader = new HailClassLoader(getClass.getClassLoader), + batchClient = client, + curBatchId = None, + curJobGroupId = None, + flags = flags, + tmpdir = rpcConfig.tmp_dir, + fs = fs, + serviceBackendContext = + new ServiceBackendContext( + rpcConfig.billing_project, + rpcConfig.remote_tmpdir, + rpcConfig.worker_cores, + rpcConfig.worker_memory, + rpcConfig.storage, + rpcConfig.regions, + rpcConfig.cloudfuse_configs, + profile = false, + ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir), + ), + scratchDir = rpcConfig.remote_tmpdir, + ) + } + def withMockDriverContext(test: ServiceBackendRPCPayload => Any): Any = + using(LocalTmpFolder) { tmp => withObjectSpied[is.hail.utils.UtilsType] { // not obvious how to pull out `tokenUrlSafe` and inject this directory // using a spy is a hack and i don't particularly like it. @@ -125,8 +148,8 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito { test { ServiceBackendRPCPayload( - tmp_dir = "", - remote_tmpdir = tmp.path + "/", // because raw strings... 
+ tmp_dir = tmp.path, + remote_tmpdir = tmp.path, billing_project = "fancy", worker_cores = "128", worker_memory = "a lot.", @@ -142,10 +165,8 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito { } } - def withNewLocalTmpFolder[A](f: Directory => A): A = { - val tmp = Directory.makeTemp("hail-testing-tmp", "") - try f(tmp) - finally tmp.deleteRecursively() - } - + def LocalTmpFolder: Directory with Closeable = + new Directory(Directory.makeTemp("hail-testing-tmp").jfile) with Closeable { + override def close(): Unit = deleteRecursively() + } } diff --git a/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala b/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala index 3dd7be5b944..edca1194921 100644 --- a/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala +++ b/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala @@ -2,8 +2,6 @@ package is.hail.io.fs import is.hail.services.oauth2.AzureCloudCredentials -import java.io.FileInputStream -import org.apache.commons.io.IOUtils import org.testng.SkipException import org.testng.annotations.{BeforeClass, Test} diff --git a/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala b/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala index f2c911a3fd2..dcd019d2e94 100644 --- a/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala +++ b/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala @@ -1,6 +1,7 @@ package is.hail.io.fs import is.hail.services.oauth2.GoogleCloudCredentials + import org.scalatestplus.testng.TestNGSuite import org.testng.SkipException import org.testng.annotations.{BeforeClass, Test} diff --git a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala index ddbbec0573e..af1aaeced4d 100644 --- a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala +++ b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala @@ -1,15 +1,15 @@ package is.hail.services import is.hail.utils._ -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test import java.nio.file.Path +import org.scalatestplus.testng.TestNGSuite +import org.testng.annotations.Test + class BatchClientSuite extends TestNGSuite { - @Test def testBasic(): Unit = { + @Test def testBasic(): Unit = using(BatchClient(DeployConfig.get(), Path.of("/test-gsa-key/key.json"))) { client => - val jobGroup = client.run( BatchRequest( billing_project = "test", @@ -36,5 +36,4 @@ class BatchClientSuite extends TestNGSuite { assert(jobGroup.state == JobGroupStates.Success) } - } } From 4f252a07ecfad9e6febfa9f65a2931240f797189 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Mon, 16 Sep 2024 20:53:33 -0400 Subject: [PATCH 04/20] reformat --- .../scala/is/hail/backend/BackendServer.scala | 5 +++-- .../hail/backend/service/ServiceBackend.scala | 14 +++++++----- .../is/hail/backend/service/Worker.scala | 14 +++++++----- .../scala/is/hail/io/fs/AzureStorageFS.scala | 14 ++++++------ hail/src/main/scala/is/hail/io/fs/FS.scala | 3 ++- .../scala/is/hail/io/fs/GoogleStorageFS.scala | 22 +++++++++++-------- .../main/scala/is/hail/io/fs/RouterFS.scala | 3 ++- .../is/hail/io/fs/TerraAzureStorageFS.scala | 7 +++--- .../scala/is/hail/services/BatchClient.scala | 14 +++++++----- .../scala/is/hail/services/BatchConfig.scala | 5 +++-- .../main/scala/is/hail/services/oauth2.scala | 10 +++++---- .../main/scala/is/hail/services/package.scala | 5 +++-- .../scala/is/hail/services/requests.scala | 9 ++++---- 13 files changed, 73 insertions(+), 
52 deletions(-) diff --git a/hail/src/main/scala/is/hail/backend/BackendServer.scala b/hail/src/main/scala/is/hail/backend/BackendServer.scala index 0217fc1d8d7..7ce224548c9 100644 --- a/hail/src/main/scala/is/hail/backend/BackendServer.scala +++ b/hail/src/main/scala/is/hail/backend/BackendServer.scala @@ -4,16 +4,17 @@ import is.hail.expr.ir.{IRParser, IRParserEnvironment} import is.hail.utils._ import scala.util.control.NonFatal + +import java.io.Closeable import java.net.InetSocketAddress import java.nio.charset.StandardCharsets import java.util.concurrent._ + import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer} import org.json4s._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods.compact -import java.io.Closeable - case class IRTypePayload(ir: String) case class LoadReferencesFromDatasetPayload(path: String) diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 668a3ec2286..c4bb2d40388 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -5,14 +5,17 @@ import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ import is.hail.expr.Validate -import is.hail.expr.ir.{Compile, IR, IRParser, IRParserEnvironment, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck} +import is.hail.expr.ir.{ + Compile, IR, IRParser, IRParserEnvironment, IRSize, LoweringAnalyses, MakeTuple, SortField, + TableIR, TableReader, TypeCheck, +} import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering._ import is.hail.io.fs._ import is.hail.linalg.BlockMatrix -import is.hail.services.JobGroupStates.Failure import is.hail.services.{BatchClient, _} +import is.hail.services.JobGroupStates.Failure import is.hail.types._ import is.hail.types.physical._ import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType @@ -22,17 +25,18 @@ import is.hail.variant.ReferenceGenome import scala.annotation.switch import scala.reflect.ClassTag + import java.io._ import java.nio.charset.StandardCharsets +import java.nio.file.Path import java.util.concurrent._ + import org.apache.log4j.Logger import org.json4s.{DefaultFormats, Formats} import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods import sourcecode.Enclosing -import java.nio.file.Path - class ServiceBackendContext( val billingProject: String, val remoteTmpDir: String, @@ -181,7 +185,7 @@ class ServiceBackend( val uploadContexts = executor.submit[Unit](() => retryTransientErrors { fs.writePDOS(s"$root/contexts") { os => - var o = 12L * n // 12L = sizeof(Long) + sizeof(Int) + var o = 12L * n // 12L = sizeof(Long) + sizeof(Int) collection.foreach { context => val len = context.length os.writeLong(o) diff --git a/hail/src/main/scala/is/hail/backend/service/Worker.scala b/hail/src/main/scala/is/hail/backend/service/Worker.scala index 5a6ec9f9505..de608b26515 100644 --- a/hail/src/main/scala/is/hail/backend/service/Worker.scala +++ b/hail/src/main/scala/is/hail/backend/service/Worker.scala @@ -1,22 +1,24 @@ package is.hail.backend.service +import is.hail.{HAIL_REVISION, HailContext, HailFeatureFlags} import is.hail.asm4s._ import is.hail.backend.HailTaskContext import is.hail.io.fs._ import is.hail.services._ import is.hail.utils._ -import is.hail.{HAIL_REVISION, HailContext, HailFeatureFlags} -import 
org.apache.log4j.Logger + +import scala.collection.mutable +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration.Duration +import scala.util.control.NonFatal import java.io._ import java.nio.charset._ import java.nio.file.Path import java.util import java.util.{concurrent => javaConcurrent} -import scala.collection.mutable -import scala.concurrent.duration.Duration -import scala.concurrent.{Await, ExecutionContext, Future} -import scala.util.control.NonFatal + +import org.apache.log4j.Logger class ServiceTaskContext(val partitionId: Int) extends HailTaskContext { override def stageId(): Int = 0 diff --git a/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala index a4d99c63fe2..a640fc68419 100644 --- a/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala @@ -5,21 +5,21 @@ import is.hail.services.oauth2.AzureCloudCredentials import is.hail.services.retryTransientErrors import is.hail.shadedazure.com.azure.core.credential.AzureSasCredential import is.hail.shadedazure.com.azure.core.util.HttpClientOptions +import is.hail.shadedazure.com.azure.storage.blob.{ + BlobClient, BlobContainerClient, BlobServiceClient, BlobServiceClientBuilder, +} import is.hail.shadedazure.com.azure.storage.blob.models.{ BlobItem, BlobRange, BlobStorageException, ListBlobsOptions, } import is.hail.shadedazure.com.azure.storage.blob.specialized.BlockBlobClient -import is.hail.shadedazure.com.azure.storage.blob.{ - BlobClient, BlobContainerClient, BlobServiceClient, BlobServiceClientBuilder, -} -import java.io.{ByteArrayOutputStream, FileNotFoundException, OutputStream} -import java.nio.file.Paths -import java.time.Duration import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import java.nio.file.Path + +import java.io.{ByteArrayOutputStream, FileNotFoundException, OutputStream} +import java.nio.file.{Path, Paths} +import java.time.Duration class AzureStorageFSURL( val account: String, diff --git a/hail/src/main/scala/is/hail/io/fs/FS.scala b/hail/src/main/scala/is/hail/io/fs/FS.scala index 441bde2cdfd..f7de8037b74 100644 --- a/hail/src/main/scala/is/hail/io/fs/FS.scala +++ b/hail/src/main/scala/is/hail/io/fs/FS.scala @@ -1,6 +1,5 @@ package is.hail.io.fs - import is.hail.HailContext import is.hail.backend.BroadcastValue import is.hail.io.compress.{BGzipInputStream, BGzipOutputStream} @@ -10,10 +9,12 @@ import is.hail.utils._ import scala.collection.mutable import scala.io.Source + import java.io._ import java.nio.ByteBuffer import java.nio.file.FileSystems import java.util.zip.GZIPOutputStream + import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream import org.apache.commons.io.IOUtils import org.apache.hadoop diff --git a/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala index 9ee4c1cad1e..6a2921f07de 100644 --- a/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala @@ -1,21 +1,25 @@ package is.hail.io.fs -import com.google.api.client.googleapis.json.GoogleJsonResponseException -import com.google.cloud.http.HttpTransportOptions -import com.google.cloud.storage.Storage.{BlobGetOption, BlobListOption, BlobSourceOption, BlobWriteOption} -import com.google.cloud.storage.{Option => _, _} -import com.google.cloud.{ReadChannel, WriteChannel} import 
is.hail.HailFeatureFlags import is.hail.io.fs.FSUtil.dropTrailingSlash import is.hail.io.fs.GoogleStorageFS.RequesterPaysFailure -import is.hail.services.oauth2.GoogleCloudCredentials import is.hail.services.{isTransientError, retryTransientErrors} +import is.hail.services.oauth2.GoogleCloudCredentials import is.hail.utils._ +import scala.jdk.CollectionConverters._ + import java.io.{FileNotFoundException, IOException} import java.nio.ByteBuffer import java.nio.file.{Path, Paths} -import scala.jdk.CollectionConverters._ + +import com.google.api.client.googleapis.json.GoogleJsonResponseException +import com.google.cloud.{ReadChannel, WriteChannel} +import com.google.cloud.http.HttpTransportOptions +import com.google.cloud.storage.{Option => _, _} +import com.google.cloud.storage.Storage.{ + BlobGetOption, BlobListOption, BlobSourceOption, BlobWriteOption, +} case class GoogleStorageFSURL(bucket: String, path: String) extends FSURL { def addPathComponent(c: String): GoogleStorageFSURL = @@ -70,13 +74,13 @@ object GoogleStorageFS { case exc: StorageException => Option(exc.getMessage).exists { message => message == "userProjectMissing" || - (exc.getCode == 400 && message.contains("requester pays")) + (exc.getCode == 400 && message.contains("requester pays")) } case exc: GoogleJsonResponseException => Option(exc.getMessage).exists { message => message == "userProjectMissing" || - (exc.getStatusCode == 400 && message.contains("requester pays")) + (exc.getStatusCode == 400 && message.contains("requester pays")) } case _ => diff --git a/hail/src/main/scala/is/hail/io/fs/RouterFS.scala b/hail/src/main/scala/is/hail/io/fs/RouterFS.scala index b61bad0d74b..d89985c0780 100644 --- a/hail/src/main/scala/is/hail/io/fs/RouterFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/RouterFS.scala @@ -3,11 +3,12 @@ package is.hail.io.fs import is.hail.HailFeatureFlags import is.hail.services.oauth2.{AzureCloudCredentials, GoogleCloudCredentials} import is.hail.utils.{FastSeq, SerializableHadoopConfiguration} -import org.apache.hadoop.conf.Configuration import java.io.Serializable import java.nio.file.Path +import org.apache.hadoop.conf.Configuration + object RouterFSURL { def apply(fs: FS)(_url: fs.URL): RouterFSURL = RouterFSURL(_url, fs) } diff --git a/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala index 98018a8c389..ad5f97e1abc 100644 --- a/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala @@ -3,14 +3,15 @@ package is.hail.io.fs import is.hail.services.oauth2.AzureCloudCredentials import is.hail.shadedazure.com.azure.storage.blob.BlobServiceClient import is.hail.utils._ + +import scala.collection.mutable + import org.apache.http.client.methods.HttpPost import org.apache.http.client.utils.URIBuilder import org.apache.http.impl.client.HttpClients import org.apache.http.util.EntityUtils -import org.json4s.jackson.JsonMethods import org.json4s.{DefaultFormats, Formats} - -import scala.collection.mutable +import org.json4s.jackson.JsonMethods object TerraAzureStorageFS { private val TEN_MINUTES_IN_MS = 10 * 60 * 1000 diff --git a/hail/src/main/scala/is/hail/services/BatchClient.scala b/hail/src/main/scala/is/hail/services/BatchClient.scala index 5d458900484..dd17671eeac 100644 --- a/hail/src/main/scala/is/hail/services/BatchClient.scala +++ b/hail/src/main/scala/is/hail/services/BatchClient.scala @@ -3,15 +3,17 @@ package is.hail.services import 
is.hail.expr.ir.ByteArrayBuilder import is.hail.services.requests.{BatchServiceRequester, Requester} import is.hail.utils._ + +import scala.util.Random + +import java.nio.charset.StandardCharsets +import java.nio.file.Path + import org.apache.http.entity.ByteArrayEntity import org.apache.http.entity.ContentType.APPLICATION_JSON +import org.json4s.{CustomSerializer, DefaultFormats, Extraction, Formats, JInt, JObject, JString} import org.json4s.JsonAST.{JArray, JBool} import org.json4s.jackson.JsonMethods -import org.json4s.{CustomSerializer, DefaultFormats, Extraction, Formats, JInt, JObject, JString} - -import java.nio.charset.StandardCharsets -import java.nio.file.Path -import scala.util.Random case class BatchRequest( billing_project: String, @@ -87,7 +89,7 @@ object JobGroupStates { object BatchClient { def apply(deployConfig: DeployConfig, credentialsFile: Path, env: Map[String, String] = sys.env) - : BatchClient = + : BatchClient = new BatchClient(BatchServiceRequester(deployConfig, credentialsFile, env)) } diff --git a/hail/src/main/scala/is/hail/services/BatchConfig.scala b/hail/src/main/scala/is/hail/services/BatchConfig.scala index 3cd8e3b0c62..da15265b849 100644 --- a/hail/src/main/scala/is/hail/services/BatchConfig.scala +++ b/hail/src/main/scala/is/hail/services/BatchConfig.scala @@ -1,11 +1,12 @@ package is.hail.services import is.hail.utils._ -import org.json4s._ -import org.json4s.jackson.JsonMethods import java.nio.file.{Files, Path} +import org.json4s._ +import org.json4s.jackson.JsonMethods + object BatchConfig { def fromConfigFile(file: Path): Option[BatchConfig] = if (!file.toFile.exists()) None diff --git a/hail/src/main/scala/is/hail/services/oauth2.scala b/hail/src/main/scala/is/hail/services/oauth2.scala index 657521600ce..06c26c653a4 100644 --- a/hail/src/main/scala/is/hail/services/oauth2.scala +++ b/hail/src/main/scala/is/hail/services/oauth2.scala @@ -1,6 +1,5 @@ package is.hail.services -import com.google.auth.oauth2.{GoogleCredentials, ServiceAccountCredentials} import is.hail.services.oauth2.AzureCloudCredentials.EnvVars.AzureApplicationCredentials import is.hail.services.oauth2.GoogleCloudCredentials.EnvVars.GoogleApplicationCredentials import is.hail.shadedazure.com.azure.core.credential.{TokenCredential, TokenRequestContext} @@ -8,12 +7,15 @@ import is.hail.shadedazure.com.azure.identity.{ ClientSecretCredentialBuilder, DefaultAzureCredentialBuilder, } import is.hail.utils.{defaultJSONFormats, using} -import org.json4s.Formats -import org.json4s.jackson.JsonMethods + +import scala.collection.JavaConverters._ import java.io.Serializable import java.nio.file.{Files, Path} -import scala.collection.JavaConverters._ + +import com.google.auth.oauth2.{GoogleCredentials, ServiceAccountCredentials} +import org.json4s.Formats +import org.json4s.jackson.JsonMethods object oauth2 { diff --git a/hail/src/main/scala/is/hail/services/package.scala b/hail/src/main/scala/is/hail/services/package.scala index c0369af14b6..e7887fea509 100644 --- a/hail/src/main/scala/is/hail/services/package.scala +++ b/hail/src/main/scala/is/hail/services/package.scala @@ -1,16 +1,17 @@ package is.hail +import is.hail.services.requests.ClientResponseException import is.hail.shadedazure.com.azure.storage.common.implementation.Constants import is.hail.utils._ import scala.util.Random + import java.io._ import java.net._ + import com.google.api.client.googleapis.json.GoogleJsonResponseException import com.google.api.client.http.HttpResponseException import 
com.google.cloud.storage.StorageException -import is.hail.services.requests.ClientResponseException - import javax.net.ssl.SSLException import org.apache.http.{ConnectionClosedException, NoHttpResponseException} import org.apache.http.conn.HttpHostConnectException diff --git a/hail/src/main/scala/is/hail/services/requests.scala b/hail/src/main/scala/is/hail/services/requests.scala index 05a13ed057c..855c219b631 100644 --- a/hail/src/main/scala/is/hail/services/requests.scala +++ b/hail/src/main/scala/is/hail/services/requests.scala @@ -2,20 +2,21 @@ package is.hail.services import is.hail.services.oauth2.{CloudCredentials, CloudScopes} import is.hail.utils.{log, _} + +import java.net.URL +import java.nio.file.Path + +import org.apache.http.{HttpEntity, HttpEntityEnclosingRequest} import org.apache.http.client.config.RequestConfig import org.apache.http.client.methods.{HttpGet, HttpPatch, HttpPost, HttpUriRequest} import org.apache.http.entity.ContentType.APPLICATION_JSON import org.apache.http.entity.StringEntity import org.apache.http.impl.client.{CloseableHttpClient, HttpClients} import org.apache.http.util.EntityUtils -import org.apache.http.{HttpEntity, HttpEntityEnclosingRequest} import org.json4s.JValue import org.json4s.JsonAST.JNothing import org.json4s.jackson.JsonMethods -import java.net.URL -import java.nio.file.Path - object requests { class ClientResponseException(val status: Int, message: String) extends Exception(message) From 9b4667339c6c71d2676b6ff6c9b632d95e5befa3 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Tue, 17 Sep 2024 11:46:57 -0400 Subject: [PATCH 05/20] `stop` in python, closable only in scala --- hail/python/hail/backend/py4j_backend.py | 3 ++- hail/python/hail/context.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py index 4e216e22aad..2c1610b5614 100644 --- a/hail/python/hail/backend/py4j_backend.py +++ b/hail/python/hail/backend/py4j_backend.py @@ -187,7 +187,7 @@ def decode_bytearray(encoded): self._jbackend = jbackend self._jhc = jhc - self._backend_server = self._hail_package.backend.BackendServer.apply(self._jbackend) + self._backend_server = self._hail_package.backend.BackendServer(self._jbackend) self._backend_server_port: int = self._backend_server.port() self._backend_server.start() self._requests_session = requests.Session() @@ -307,6 +307,7 @@ def _to_java_blockmatrix_ir(self, ir): def stop(self): self._backend_server.close() + self._jbackend.close() self._jhc.stop() self._jhc = None self._registered_ir_function_names = set() diff --git a/hail/python/hail/context.py b/hail/python/hail/context.py index 56897eec987..5258f27fbc1 100644 --- a/hail/python/hail/context.py +++ b/hail/python/hail/context.py @@ -151,7 +151,7 @@ def default_reference(self, value): def stop(self): assert self._backend - self._backend.close() + self._backend.stop() self._backend = None Env._hc = None Env._dummy_table = None From 9d6572e93e013b7da0f6946db3b1604915300a5a Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Mon, 21 Oct 2024 11:24:06 -0400 Subject: [PATCH 06/20] credentials are scoped --- .../hail/backend/service/ServiceBackend.scala | 2 +- .../scala/is/hail/io/fs/AzureStorageFS.scala | 4 + .../scala/is/hail/io/fs/GoogleStorageFS.scala | 3 + .../main/scala/is/hail/io/fs/RouterFS.scala | 15 ++- .../is/hail/io/fs/TerraAzureStorageFS.scala | 6 +- .../scala/is/hail/services/BatchClient.scala | 26 ++++- .../main/scala/is/hail/services/oauth2.scala | 94 
+++++++++++-------- .../scala/is/hail/services/requests.scala | 15 +-- .../is/hail/io/fs/AzureStorageFSSuite.scala | 3 +- .../is/hail/io/fs/GoogleStorageFSSuite.scala | 3 +- .../is/hail/services/BatchClientSuite.scala | 12 ++- 11 files changed, 116 insertions(+), 67 deletions(-) diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index c4bb2d40388..5a00bb76776 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -212,7 +212,7 @@ class ServiceBackend( in_update_job_group_id = jobGroup.job_group_id, in_update_parent_ids = Array(), process = JvmJob( - command = Array(Main.WORKER, root, s"${jobGroup.job_group_id}", s"$n"), + command = Array(Main.WORKER, root, s"$i", s"$n"), jar_url = jarLocation, profile = flags.get("profile") != null, ), diff --git a/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala index a640fc68419..2e7eb2b8672 100644 --- a/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala @@ -12,6 +12,7 @@ import is.hail.shadedazure.com.azure.storage.blob.models.{ BlobItem, BlobRange, BlobStorageException, ListBlobsOptions, } import is.hail.shadedazure.com.azure.storage.blob.specialized.BlockBlobClient +import is.hail.utils.FastSeq import scala.collection.JavaConverters._ import scala.collection.mutable @@ -58,6 +59,9 @@ object AzureStorageFS { private val AZURE_HTTPS_URI_REGEX = "^https:\\/\\/([a-z0-9_\\-\\.]+)\\.blob\\.core\\.windows\\.net\\/([a-z0-9_\\-\\.]+)(\\/.*)?".r + val RequiredOAuthScopes: IndexedSeq[String] = + FastSeq("https://storage.azure.com/.default") + def parseUrl(filename: String): AzureStorageFSURL = { AZURE_HTTPS_URI_REGEX .findFirstMatchIn(filename) diff --git a/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala index 6a2921f07de..dfcb81d30ec 100644 --- a/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala @@ -44,6 +44,9 @@ object GoogleStorageFS { private[this] val GCS_URI_REGEX = "^gs:\\/\\/([a-z0-9_\\-\\.]+)(\\/.*)?".r + val RequiredOAuthScopes: IndexedSeq[String] = + FastSeq("https://www.googleapis.com/auth/devstorage.read_write") + def parseUrl(filename: String): GoogleStorageFSURL = { val scheme = filename.split(":")(0) if (scheme == null || scheme != "gs") { diff --git a/hail/src/main/scala/is/hail/io/fs/RouterFS.scala b/hail/src/main/scala/is/hail/io/fs/RouterFS.scala index d89985c0780..2a7ab03fe37 100644 --- a/hail/src/main/scala/is/hail/io/fs/RouterFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/RouterFS.scala @@ -52,13 +52,18 @@ object RouterFS { def buildRoutes(cloudConfig: CloudStorageFSConfig, env: Map[String, String] = sys.env): FS = new RouterFS( IndexedSeq.concat( - cloudConfig.google.map { case GoogleStorageFSConfig(path, maybeRequesterPaysConfig) => - new GoogleStorageFS(GoogleCloudCredentials(path), maybeRequesterPaysConfig) + cloudConfig.google.map { case GoogleStorageFSConfig(path, mRPConfig) => + new GoogleStorageFS( + GoogleCloudCredentials(path, GoogleStorageFS.RequiredOAuthScopes, env), + mRPConfig, + ) }, cloudConfig.azure.map { case AzureStorageFSConfig(path) => - val cred = AzureCloudCredentials(path) - if (env.contains("HAIL_TERRA")) new TerraAzureStorageFS(cred) - else new AzureStorageFS(cred) + 
if (env.contains("HAIL_TERRA")) { + val creds = AzureCloudCredentials(path, TerraAzureStorageFS.RequiredOAuthScopes, env) + new TerraAzureStorageFS(creds) + } else + new AzureStorageFS(AzureCloudCredentials(path, AzureStorageFS.RequiredOAuthScopes, env)) }, FastSeq(new HadoopFS(new SerializableHadoopConfiguration(new Configuration()))), ) diff --git a/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala index ad5f97e1abc..a073e9a0b38 100644 --- a/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala @@ -15,6 +15,9 @@ import org.json4s.jackson.JsonMethods object TerraAzureStorageFS { private val TEN_MINUTES_IN_MS = 10 * 60 * 1000 + + val RequiredOAuthScopes: IndexedSeq[String] = + FastSeq("https://management.azure.com/.default") } class TerraAzureStorageFS(credential: AzureCloudCredentials) extends AzureStorageFS(credential) { @@ -55,8 +58,7 @@ class TerraAzureStorageFS(credential: AzureCloudCredentials) extends AzureStorag val url = s"$workspaceManagerUrl/api/workspaces/v1/$workspaceId/resources/controlled/azure/storageContainer/$containerResourceId/getSasToken" val req = new HttpPost(url) - val token = credential.accessToken(FastSeq("https://management.azure.com/.default")) - req.addHeader("Authorization", s"Bearer $token") + req.addHeader("Authorization", s"Bearer ${credential.accessToken}") val tenHoursInSeconds = 10 * 3600 val expiration = System.currentTimeMillis() + tenHoursInSeconds * 1000 diff --git a/hail/src/main/scala/is/hail/services/BatchClient.scala b/hail/src/main/scala/is/hail/services/BatchClient.scala index dd17671eeac..93b85910adf 100644 --- a/hail/src/main/scala/is/hail/services/BatchClient.scala +++ b/hail/src/main/scala/is/hail/services/BatchClient.scala @@ -1,11 +1,13 @@ package is.hail.services import is.hail.expr.ir.ByteArrayBuilder -import is.hail.services.requests.{BatchServiceRequester, Requester} +import is.hail.services.oauth2.CloudCredentials +import is.hail.services.requests.Requester import is.hail.utils._ import scala.util.Random +import java.net.URL import java.nio.charset.StandardCharsets import java.nio.file.Path @@ -88,9 +90,29 @@ object JobGroupStates { } object BatchClient { + + private[this] def BatchServiceScopes(env: Map[String, String]): Array[String] = + env.get("HAIL_CLOUD") match { + case Some("gcp") => + Array( + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/userinfo.email", + "openid", + ) + case Some("azure") => + env.get("HAIL_AZURE_OAUTH_SCOPE").toArray + case Some(cloud) => + throw new IllegalArgumentException(s"Unknown cloud: '$cloud'.") + case None => + throw new IllegalArgumentException(s"HAIL_CLOUD must be set.") + } + def apply(deployConfig: DeployConfig, credentialsFile: Path, env: Map[String, String] = sys.env) : BatchClient = - new BatchClient(BatchServiceRequester(deployConfig, credentialsFile, env)) + new BatchClient(Requester( + new URL(deployConfig.baseUrl("batch")), + CloudCredentials(credentialsFile, BatchServiceScopes(env), env), + )) } case class BatchClient private (req: Requester) extends Logging with AutoCloseable { diff --git a/hail/src/main/scala/is/hail/services/oauth2.scala b/hail/src/main/scala/is/hail/services/oauth2.scala index 06c26c653a4..7bce6ac70d2 100644 --- a/hail/src/main/scala/is/hail/services/oauth2.scala +++ b/hail/src/main/scala/is/hail/services/oauth2.scala @@ -2,7 +2,9 @@ package is.hail.services import 
is.hail.services.oauth2.AzureCloudCredentials.EnvVars.AzureApplicationCredentials import is.hail.services.oauth2.GoogleCloudCredentials.EnvVars.GoogleApplicationCredentials -import is.hail.shadedazure.com.azure.core.credential.{TokenCredential, TokenRequestContext} +import is.hail.shadedazure.com.azure.core.credential.{ + AccessToken, TokenCredential, TokenRequestContext, +} import is.hail.shadedazure.com.azure.identity.{ ClientSecretCredentialBuilder, DefaultAzureCredentialBuilder, } @@ -12,6 +14,7 @@ import scala.collection.JavaConverters._ import java.io.Serializable import java.nio.file.{Files, Path} +import java.time.OffsetDateTime import com.google.auth.oauth2.{GoogleCredentials, ServiceAccountCredentials} import org.json4s.Formats @@ -20,38 +23,25 @@ import org.json4s.jackson.JsonMethods object oauth2 { sealed trait CloudCredentials extends Product with Serializable { - def accessToken(scopes: IndexedSeq[String]): String + def accessToken: String } - def CloudCredentials(credentialsPath: Path, env: Map[String, String] = sys.env) - : CloudCredentials = + def CloudCredentials( + keyPath: Path, + scopes: IndexedSeq[String], + env: Map[String, String] = sys.env, + ): CloudCredentials = env.get("HAIL_CLOUD") match { - case Some("gcp") => GoogleCloudCredentials(Some(credentialsPath)) - case Some("azure") => AzureCloudCredentials(Some(credentialsPath)) + case Some("gcp") => GoogleCloudCredentials(Some(keyPath), scopes, env) + case Some("azure") => AzureCloudCredentials(Some(keyPath), scopes, env) case Some(cloud) => throw new IllegalArgumentException(s"Unknown cloud: '$cloud'") case None => throw new IllegalArgumentException(s"HAIL_CLOUD must be set.") } - def CloudScopes(env: Map[String, String] = sys.env): Array[String] = - env.get("HAIL_CLOUD") match { - case Some("gcp") => - Array( - "https://www.googleapis.com/auth/userinfo.profile", - "https://www.googleapis.com/auth/userinfo.email", - "openid", - ) - case Some("azure") => - sys.env.get("HAIL_AZURE_OAUTH_SCOPE").toArray - case Some(cloud) => - throw new IllegalArgumentException(s"Unknown cloud: '$cloud'.") - case None => - throw new IllegalArgumentException(s"HAIL_CLOUD must be set.") - } - case class GoogleCloudCredentials(value: GoogleCredentials) extends CloudCredentials { - override def accessToken(scopes: IndexedSeq[String]): String = { + override def accessToken: String = { value.refreshIfExpired() - value.createScoped(scopes.asJava).getAccessToken.getTokenValue + value.getAccessToken.getTokenValue } } @@ -60,20 +50,45 @@ object oauth2 { val GoogleApplicationCredentials = "GOOGLE_APPLICATION_CREDENTIALS" } - def apply(keyPath: Option[Path], env: Map[String, String] = sys.env): GoogleCloudCredentials = - GoogleCloudCredentials( - keyPath.orElse(env.get(GoogleApplicationCredentials).map(Path.of(_))) match { - case Some(path) => using(Files.newInputStream(path))(ServiceAccountCredentials.fromStream) - case None => GoogleCredentials.getApplicationDefault - } - ) + def apply(keyPath: Option[Path], scopes: IndexedSeq[String], env: Map[String, String] = sys.env) + : GoogleCloudCredentials = + GoogleCloudCredentials { + val creds: GoogleCredentials = + keyPath.orElse(env.get(GoogleApplicationCredentials).map(Path.of(_))) match { + case Some(path) => + using(Files.newInputStream(path))(ServiceAccountCredentials.fromStream) + case None => + GoogleCredentials.getApplicationDefault + } + + creds.createScoped(scopes: _*) + } } sealed trait AzureCloudCredentials extends CloudCredentials { + def value: TokenCredential + def scopes: 
IndexedSeq[String] + + @transient private[this] var token: AccessToken = _ + + override def accessToken: String = { + refreshIfRequired() + token.getToken + } + + private[this] def refreshIfRequired(): Unit = + if (!isExpired) token.getToken + else synchronized { + if (isExpired) { + token = value.getTokenSync(new TokenRequestContext().setScopes(scopes.asJava)) + } + + token.getToken + } - override def accessToken(scopes: IndexedSeq[String]): String = - value.getTokenSync(new TokenRequestContext().setScopes(scopes.asJava)).getToken + private[this] def isExpired: Boolean = + token == null || OffsetDateTime.now.plusHours(1).isBefore(token.getExpiresAt) } object AzureCloudCredentials { @@ -81,19 +96,22 @@ object oauth2 { val AzureApplicationCredentials = "AZURE_APPLICATION_CREDENTIALS" } - def apply(keyPath: Option[Path], env: Map[String, String] = sys.env): AzureCloudCredentials = + def apply(keyPath: Option[Path], scopes: IndexedSeq[String], env: Map[String, String] = sys.env) + : AzureCloudCredentials = keyPath.orElse(env.get(AzureApplicationCredentials).map(Path.of(_))) match { - case Some(path) => AzureClientSecretCredentials(path) - case None => AzureDefaultCredentials + case Some(path) => AzureClientSecretCredentials(path, scopes) + case None => AzureDefaultCredentials(scopes) } } - private case object AzureDefaultCredentials extends AzureCloudCredentials { + private case class AzureDefaultCredentials(scopes: IndexedSeq[String]) + extends AzureCloudCredentials { @transient override lazy val value: TokenCredential = new DefaultAzureCredentialBuilder().build() } - private case class AzureClientSecretCredentials(path: Path) extends AzureCloudCredentials { + private case class AzureClientSecretCredentials(path: Path, scopes: IndexedSeq[String]) + extends AzureCloudCredentials { @transient override lazy val value: TokenCredential = using(Files.newInputStream(path)) { is => implicit val fmts: Formats = defaultJSONFormats diff --git a/hail/src/main/scala/is/hail/services/requests.scala b/hail/src/main/scala/is/hail/services/requests.scala index 855c219b631..ed269b606e3 100644 --- a/hail/src/main/scala/is/hail/services/requests.scala +++ b/hail/src/main/scala/is/hail/services/requests.scala @@ -1,10 +1,9 @@ package is.hail.services -import is.hail.services.oauth2.{CloudCredentials, CloudScopes} +import is.hail.services.oauth2.CloudCredentials import is.hail.utils.{log, _} import java.net.URL -import java.nio.file.Path import org.apache.http.{HttpEntity, HttpEntityEnclosingRequest} import org.apache.http.client.config.RequestConfig @@ -30,15 +29,7 @@ object requests { private[this] val TIMEOUT_MS = 5 * 1000 - def BatchServiceRequester(conf: DeployConfig, keyFile: Path, env: Map[String, String] = sys.env) - : Requester = - Requester( - new URL(conf.baseUrl("batch")), - CloudCredentials(keyFile, env), - CloudScopes(env), - ) - - def Requester(baseUrl: URL, cred: CloudCredentials, scopes: IndexedSeq[String]): Requester = { + def Requester(baseUrl: URL, cred: CloudCredentials): Requester = { val httpClient: CloseableHttpClient = { log.info("creating HttpClient") @@ -67,7 +58,7 @@ object requests { def request(req: HttpUriRequest, body: Option[HttpEntity] = None): JValue = { log.info(s"request ${req.getMethod} ${req.getURI}") - req.addHeader("Authorization", s"Bearer ${cred.accessToken(scopes)}") + req.addHeader("Authorization", s"Bearer ${cred.accessToken}") body.foreach(entity => req.asInstanceOf[HttpEntityEnclosingRequest].setEntity(entity)) retryTransientErrors { using(httpClient.execute(req)) { 
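// note: cred.accessToken is re-read on every request; the credential
// implementations cache the token and handle refresh internally (oauth2.scala)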
resp => diff --git a/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala b/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala index edca1194921..ce3759d0062 100644 --- a/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala +++ b/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala @@ -16,7 +16,8 @@ class AzureStorageFSSuite extends FSSuite { } } - lazy val fs = new AzureStorageFS(AzureCloudCredentials(None)) + override lazy val fs: FS = + new AzureStorageFS(AzureCloudCredentials(None, AzureStorageFS.RequiredOAuthScopes)) @Test def testMakeQualified(): Unit = { val qualifiedFileName = "https://account.blob.core.windows.net/container/path" diff --git a/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala b/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala index dcd019d2e94..5bcf8a7fc6e 100644 --- a/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala +++ b/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala @@ -17,7 +17,8 @@ class GoogleStorageFSSuite extends TestNGSuite with FSSuite { } } - lazy val fs = new GoogleStorageFS(GoogleCloudCredentials(None), None) + override lazy val fs: FS = + new GoogleStorageFS(GoogleCloudCredentials(None, GoogleStorageFS.RequiredOAuthScopes), None) @Test def testMakeQualified(): Unit = { val qualifiedFileName = "gs://bucket/path" diff --git a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala index af1aaeced4d..4fe1ba353bf 100644 --- a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala +++ b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala @@ -6,29 +6,31 @@ import java.nio.file.Path import org.scalatestplus.testng.TestNGSuite import org.testng.annotations.Test +import sourcecode.FullName class BatchClientSuite extends TestNGSuite { @Test def testBasic(): Unit = - using(BatchClient(DeployConfig.get(), Path.of("/test-gsa-key/key.json"))) { client => + using(BatchClient(DeployConfig.get(), Path.of("/tmp/test-gsa-key/key.json"))) { client => val jobGroup = client.run( BatchRequest( billing_project = "test", - n_jobs = 1, + n_jobs = 0, token = tokenUrlSafe, + attributes = Map("name" -> s"Test ${implicitly[FullName].value}"), ), JobGroupRequest( - job_group_id = 0, + job_group_id = 1, absolute_parent_id = 0, ), FastSeq( JobRequest( - job_id = 0, + job_id = 1, always_run = false, in_update_job_group_id = 0, in_update_parent_ids = Array(), process = BashJob( image = "ubuntu:22.04", - command = Array("/bin/bash", "-c", "'hello, world!"), + command = Array("/bin/bash", "-c", "echo 'hello, hail!'"), ), ) ), From 03d5cfbd8d7c19d2a4099fb6ef4862a4b00f455d Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Wed, 23 Oct 2024 16:43:21 -0400 Subject: [PATCH 07/20] rewrite, simplify and test BatchClient --- .../hail/backend/service/ServiceBackend.scala | 68 ++--- .../is/hail/backend/service/Worker.scala | 6 +- .../scala/is/hail/services/BatchClient.scala | 260 +++++++++--------- .../scala/is/hail/services/BatchConfig.scala | 5 +- .../scala/is/hail/services/requests.scala | 2 +- .../is/hail/backend/ServiceBackendSuite.scala | 40 +-- .../is/hail/services/BatchClientSuite.scala | 105 +++++-- 7 files changed, 254 insertions(+), 232 deletions(-) diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 5a00bb76776..2506dfcab53 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ 
b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -14,7 +14,7 @@ import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering._ import is.hail.io.fs._ import is.hail.linalg.BlockMatrix -import is.hail.services.{BatchClient, _} +import is.hail.services.{BatchClient, JobGroupRequest, _} import is.hail.services.JobGroupStates.Failure import is.hail.types._ import is.hail.types.physical._ @@ -53,12 +53,11 @@ object ServiceBackend { private val log = Logger.getLogger(getClass.getName()) def apply( - jarLocation: String, + jarSpec: JarSpec, name: String, theHailClassLoader: HailClassLoader, batchClient: BatchClient, - batchId: Option[Int], - jobGroupId: Option[Int], + batchConfig: BatchConfig, scratchDir: String = sys.env.getOrElse("HAIL_WORKER_SCRATCH_DIR", ""), rpcConfig: ServiceBackendRPCPayload, env: Map[String, String], @@ -87,12 +86,11 @@ object ServiceBackend { ) val backend = new ServiceBackend( - jarLocation, + jarSpec, name, theHailClassLoader, batchClient, - batchId, - jobGroupId, + batchConfig, flags, rpcConfig.tmp_dir, fs, @@ -116,12 +114,11 @@ object ServiceBackend { } class ServiceBackend( - val jarLocation: String, + val jarSpec: JarSpec, var name: String, val theHailClassLoader: HailClassLoader, val batchClient: BatchClient, - val curBatchId: Option[Int], - val curJobGroupId: Option[Int], + val batchConfig: BatchConfig, val flags: HailFeatureFlags, val tmpdir: String, val fs: FS, @@ -197,23 +194,12 @@ class ServiceBackend( } ) - val jobGroup = JobGroupRequest( - job_group_id = 1, // QoB creates an update for every new stage - absolute_parent_id = curJobGroupId.getOrElse(0), - attributes = Map("name" -> stageIdentifier), - ) - - log.info(s"worker job group spec: $jobGroup") - val jobs = collection.indices.map { i => JobRequest( - job_id = i + 1, always_run = false, - in_update_job_group_id = jobGroup.job_group_id, - in_update_parent_ids = Array(), process = JvmJob( command = Array(Main.WORKER, root, s"$i", s"$n"), - jar_url = jarLocation, + spec = jarSpec, profile = flags.get("profile") != null, ), resources = Some( @@ -235,23 +221,22 @@ class ServiceBackend( log.info(s"parallelizeAndComputeWithIndex: $token: running job") - val batchId = curBatchId.getOrElse { - batchClient.newBatch( - BatchRequest( - billing_project = backendContext.billingProject, - n_jobs = 0, - token = token, - attributes = Map("name" -> name), - ) + val jobGroupId = batchClient.newJobGroup( + JobGroupRequest( + batch_id = batchConfig.batchId, + absolute_parent_id = batchConfig.jobGroupId, + token = tokenUrlSafe, + attributes = Map("name" -> stageIdentifier), + jobs = jobs, ) - } + ) - val (updateId, jobGroupId) = batchClient.newJobGroup(batchId, token, jobGroup, jobs) - val response = batchClient.waitForJobGroup(batchId, jobGroupId) + Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms + val response = batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId) stageCount += 1 if (response.state == Failure) { - throw new HailBatchFailure(s"Update $updateId for batch $batchId failed") + throw new HailBatchFailure(s"JobGroup $jobGroupId for batch ${batchConfig.batchId} failed") } (token, root, n) @@ -412,13 +397,17 @@ object ServiceBackendAPI { val scratchDir = argv(0) // val logFile = argv(1) - val jarLocation = argv(2) + val jarSpecStr = argv(2) val kind = argv(3) assert(kind == Main.DRIVER) val name = argv(4) val inputURL = argv(5) val outputURL = argv(6) + implicit val formats: Formats = DefaultFormats + JarSpecFormats + + val 
jarGitRevision = JsonMethods.parse(jarSpecStr).extract[JarSpec] + val fs = RouterFS.buildRoutes( CloudStorageFSConfig.fromFlagsAndEnv( Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), @@ -434,23 +423,18 @@ object ServiceBackendAPI { val batchConfig = BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")) - val batchId = batchConfig.map(_.batchId) - val jobGroupId = batchConfig.map(_.jobGroupId) log.info("BatchConfig parsed.") - implicit val formats: Formats = DefaultFormats - val input = using(fs.openNoCompression(inputURL))(JsonMethods.parse(_)) val rpcConfig = (input \ "config").extract[ServiceBackendRPCPayload] // FIXME: when can the classloader be shared? (optimizer benefits!) val backend = ServiceBackend( - jarLocation, + jarGitRevision, name, new HailClassLoader(getClass().getClassLoader()), batchClient, - batchId, - jobGroupId, + batchConfig, scratchDir, rpcConfig, sys.env, diff --git a/hail/src/main/scala/is/hail/backend/service/Worker.scala b/hail/src/main/scala/is/hail/backend/service/Worker.scala index de608b26515..5722e3bc435 100644 --- a/hail/src/main/scala/is/hail/backend/service/Worker.scala +++ b/hail/src/main/scala/is/hail/backend/service/Worker.scala @@ -172,8 +172,7 @@ object Worker { null, new HailClassLoader(getClass().getClassLoader()), null, - None, - None, + null, null, null, null, @@ -188,8 +187,7 @@ object Worker { null, new HailClassLoader(getClass().getClassLoader()), null, - None, - None, + null, null, null, null, diff --git a/hail/src/main/scala/is/hail/services/BatchClient.scala b/hail/src/main/scala/is/hail/services/BatchClient.scala index 93b85910adf..a79dc7fcead 100644 --- a/hail/src/main/scala/is/hail/services/BatchClient.scala +++ b/hail/src/main/scala/is/hail/services/BatchClient.scala @@ -1,6 +1,7 @@ package is.hail.services import is.hail.expr.ir.ByteArrayBuilder +import is.hail.services.BatchClient.BunchMaxSizeBytes import is.hail.services.oauth2.CloudCredentials import is.hail.services.requests.Requester import is.hail.utils._ @@ -8,7 +9,7 @@ import is.hail.utils._ import scala.util.Random import java.net.URL -import java.nio.charset.StandardCharsets +import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.Path import org.apache.http.entity.ByteArrayEntity @@ -19,41 +20,54 @@ import org.json4s.jackson.JsonMethods case class BatchRequest( billing_project: String, - n_jobs: Int, token: String, + n_jobs: Int, attributes: Map[String, String] = Map.empty, ) case class JobGroupRequest( - job_group_id: Int, + batch_id: Int, absolute_parent_id: Int, + token: String, attributes: Map[String, String] = Map.empty, + jobs: IndexedSeq[JobRequest] = FastSeq(), ) case class JobRequest( - job_id: Int, always_run: Boolean, - in_update_job_group_id: Int, - in_update_parent_ids: Array[Int], process: JobProcess, + attributes: Map[String, String] = Map.empty, + cloudfuse: Option[Array[CloudfuseConfig]] = None, resources: Option[JobResources] = None, regions: Option[Array[String]] = None, - cloudfuse: Option[Array[CloudfuseConfig]] = None, - attributes: Map[String, String] = Map.empty, ) sealed trait JobProcess - -case class BashJob( - image: String, - command: Array[String], -) extends JobProcess - -case class JvmJob( - command: Array[String], - jar_url: String, - profile: Boolean, -) extends JobProcess +case class BashJob(image: String, command: Array[String]) extends JobProcess +case class JvmJob(command: Array[String], spec: JarSpec, profile: Boolean) extends JobProcess + +sealed trait JarSpec +case class GitRevision(sha: 
String) extends JarSpec +case class JarUrl(url: String) extends JarSpec + +object JarSpecFormats extends CustomSerializer[JarSpec](implicit fmts => + ( + { + case obj: JObject => + val value = (obj \ "value").extract[String] + (obj \ "type").extract[String] match { + case "jar_url" => JarUrl(value) + case "git_revision" => GitRevision(value) + } + }, + { + case JarUrl(url) => + JObject("type" -> JString("jar_url"), "value" -> JString(url)) + case GitRevision(sha) => + JObject("type" -> JString("git_revision"), "value" -> JString(sha)) + }, + ) + ) case class JobResources( preemptible: Boolean, @@ -113,6 +127,8 @@ object BatchClient { new URL(deployConfig.baseUrl("batch")), CloudCredentials(credentialsFile, BatchServiceScopes(env), env), )) + + private val BunchMaxSizeBytes: Int = 1024 * 1024 } case class BatchClient private (req: Requester) extends Logging with AutoCloseable { @@ -121,99 +137,43 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab DefaultFormats + JobProcessRequestSerializer + JobGroupStateDeserializer + - JobGroupResponseDeserializer + JobGroupResponseDeserializer + + JarSpecFormats def newBatch(createRequest: BatchRequest): Int = { val response = req.post("/api/v1alpha/batches/create", Extraction.decompose(createRequest)) val batchId = (response \ "id").extract[Int] - log.info(s"run: created batch $batchId") + log.info(s"Created batch $batchId") batchId } - def newJobGroup( - batchId: Int, - token: String, - jobGroup: JobGroupRequest, - jobs: IndexedSeq[JobRequest], - ): (Int, Int) = { - - val updateJson = JObject( - "n_jobs" -> JInt(jobs.length), - "n_job_groups" -> JInt(1), - "token" -> JString(token), - ) - - val jobGroupSpec = getJsonBytes(jobGroup) - val jobBunches = createBunches(jobs) - val updateIDAndJobGroupId = - if (jobBunches.length == 1 && jobBunches(0).length + jobGroupSpec.length < 1024 * 1024) { - val b = new ByteArrayBuilder() - b ++= "{\"job_groups\":".getBytes(StandardCharsets.UTF_8) - addBunchBytes(b, Array(jobGroupSpec)) - b ++= ",\"bunch\":".getBytes(StandardCharsets.UTF_8) - addBunchBytes(b, jobBunches(0)) - b ++= ",\"update\":".getBytes(StandardCharsets.UTF_8) - b ++= JsonMethods.compact(updateJson).getBytes(StandardCharsets.UTF_8) - b += '}' - val data = b.result() - val resp = req.post( - s"/api/v1alpha/batches/$batchId/update-fast", - new ByteArrayEntity(data, APPLICATION_JSON), - ) - b.clear() - ((resp \ "update_id").extract[Int], (resp \ "start_job_group_id").extract[Int]) - } else { - val resp = req.post(s"/api/v1alpha/batches/$batchId/updates/create", updateJson) - val updateID = (resp \ "update_id").extract[Int] - val startJobGroupId = (resp \ "start_job_group_id").extract[Int] - - val b = new ByteArrayBuilder() - b ++= "[".getBytes(StandardCharsets.UTF_8) - b ++= jobGroupSpec - b ++= "]".getBytes(StandardCharsets.UTF_8) - req.post( - s"/api/v1alpha/batches/$batchId/updates/$updateID/job-groups/create", - new ByteArrayEntity(b.result(), APPLICATION_JSON), - ) + def newJobGroup(req: JobGroupRequest): Int = { + val nJobs = req.jobs.length + val (updateId, startJobGroupId) = beginUpdate(req.batch_id, req.token, nJobs) + log.info(s"Began update '$updateId' for batch '${req.batch_id}'.") - b.clear() - var i = 0 - while (i < jobBunches.length) { - addBunchBytes(b, jobBunches(i)) - val data = b.result() - req.post( - s"/api/v1alpha/batches/$batchId/updates/$updateID/jobs/create", - new ByteArrayEntity(data, APPLICATION_JSON), - ) - b.clear() - i += 1 - } + createJobGroup(updateId, req) + log.info(s"Created job group 
$startJobGroupId for batch ${req.batch_id}") - req.patch(s"/api/v1alpha/batches/$b/updates/$updateID/commit") - (updateID, startJobGroupId) - } + createJobsIncremental(req.batch_id, updateId, req.jobs) + log.info(s"Submitted $nJobs in job group $startJobGroupId for batch ${req.batch_id}") - log.info(s"run: created update $updateIDAndJobGroupId for batch $batchId") - updateIDAndJobGroupId - } + commitUpdate(req.batch_id, updateId) + log.info(s"Committed update $updateId for batch ${req.batch_id}.") - def run(batchRequest: BatchRequest, jobGroup: JobGroupRequest, jobs: IndexedSeq[JobRequest]) - : JobGroupResponse = { - val batchID = newBatch(batchRequest) - val (_, jobGroupId) = newJobGroup(batchID, batchRequest.token, jobGroup, jobs) - waitForJobGroup(batchID, jobGroupId) + startJobGroupId } - def waitForJobGroup(batchID: Int, jobGroupId: Int): JobGroupResponse = { - - Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms + def getJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse = + req + .get(s"/api/v1alpha/batches/$batchId/job-groups/$jobGroupId") + .extract[JobGroupResponse] + def waitForJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse = { val start = System.nanoTime() while (true) { - val jobGroup = req - .get(s"/api/v1alpha/batches/$batchID/job-groups/$jobGroupId") - .extract[JobGroupResponse] + val jobGroup = getJobGroup(batchId, jobGroupId) if (jobGroup.complete) return jobGroup @@ -236,47 +196,85 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab throw new AssertionError("unreachable") } - private def createBunches(jobs: IndexedSeq[JobRequest]): BoxedArrayBuilder[Array[Array[Byte]]] = { - val bunches = new BoxedArrayBuilder[Array[Array[Byte]]]() - val bunchb = new BoxedArrayBuilder[Array[Byte]]() - - var i = 0 - var size = 0 - while (i < jobs.length) { - val jobBytes = getJsonBytes(jobs(i)) - if (size + jobBytes.length > 1024 * 1024) { - bunches += bunchb.result() - bunchb.clear() - size = 0 + override def close(): Unit = + req.close() + + private[this] def createJobsIncremental( + batchId: Int, + updateId: Int, + jobs: IndexedSeq[JobRequest], + ): Unit = { + val buff = new ByteArrayBuilder(BunchMaxSizeBytes) + var sym = "[" + + def flush(): Unit = { + buff ++= "]".getBytes(UTF_8) + req.post( + s"/api/v1alpha/batches/$batchId/updates/$updateId/jobs/create", + new ByteArrayEntity(buff.result(), APPLICATION_JSON), + ) + buff.clear() + sym = "[" + } + + for ((job, idx) <- jobs.zipWithIndex) { + val jobPayload = jobToJson(job, idx).getBytes(UTF_8) + + if (buff.size + jobPayload.length > BunchMaxSizeBytes) { + flush() } - bunchb += jobBytes - size += jobBytes.length - i += 1 + + buff ++= sym.getBytes(UTF_8) + buff ++= jobPayload + sym = "," } - assert(bunchb.size > 0) - bunches += bunchb.result() - bunchb.clear() - bunches + if (buff.size > 0) { flush() } } - private def getJsonBytes(obj: Any): Array[Byte] = - JsonMethods.compact(Extraction.decompose(obj)).getBytes(StandardCharsets.UTF_8) - - private def addBunchBytes(b: ByteArrayBuilder, bunch: Array[Array[Byte]]): Unit = { - var j = 0 - b += '[' - while (j < bunch.length) { - if (j > 0) - b += ',' - b ++= bunch(j) - j += 1 + private[this] def jobToJson(j: JobRequest, jobIdx: Int): String = + JsonMethods.compact { + Extraction.decompose(j) + .asInstanceOf[JObject] + .merge( + JObject( + "job_id" -> JInt(jobIdx), + "in_update_job_group_id" -> JInt(1), + ) + ) } - b += ']' - } - override def close(): Unit = - req.close() + private[this] def 
beginUpdate(batchId: Int, token: String, nJobs: Int): (Int, Int) = + req + .post( + s"/api/v1alpha/batches/$batchId/updates/create", + JObject( + "token" -> JString(token), + "n_jobs" -> JInt(nJobs), + "n_job_groups" -> JInt(1), + ), + ) + .as { case obj: JObject => + ( + (obj \ "update_id").extract[Int], + (obj \ "start_job_group_id").extract[Int], + ) + } + + private[this] def commitUpdate(batchId: Int, updateId: Int): Unit = + req.patch(s"/api/v1alpha/batches/$batchId/updates/$updateId/commit") + + private[this] def createJobGroup(updateId: Int, jobGroup: JobGroupRequest): Unit = + req.post( + s"/api/v1alpha/batches/${jobGroup.batch_id}/updates/$updateId/job-groups/create", + JArray(List( + JObject( + "job_group_id" -> JInt(1), // job group id relative to the update + "absolute_parent_id" -> JInt(jobGroup.absolute_parent_id), + "attributes" -> Extraction.decompose(jobGroup.attributes), + ) + )), + ) private[this] object JobProcessRequestSerializer extends CustomSerializer[JobProcess](_ => @@ -289,11 +287,11 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab "image" -> JString(image), "command" -> JArray(command.map(JString).toList), ) - case JvmJob(command, url, profile) => + case JvmJob(command, jarSpec, profile) => JObject( "type" -> JString("jvm"), "command" -> JArray(command.map(JString).toList), - "jar_spec" -> JObject("type" -> JString("jar_url"), "value" -> JString(url)), + "jar_spec" -> Extraction.decompose(jarSpec), "profile" -> JBool(profile), ) }, diff --git a/hail/src/main/scala/is/hail/services/BatchConfig.scala b/hail/src/main/scala/is/hail/services/BatchConfig.scala index da15265b849..6d5b08ed7b5 100644 --- a/hail/src/main/scala/is/hail/services/BatchConfig.scala +++ b/hail/src/main/scala/is/hail/services/BatchConfig.scala @@ -8,9 +8,8 @@ import org.json4s._ import org.json4s.jackson.JsonMethods object BatchConfig { - def fromConfigFile(file: Path): Option[BatchConfig] = - if (!file.toFile.exists()) None - else using(Files.newInputStream(file))(in => Some(fromConfig(JsonMethods.parse(in)))) + def fromConfigFile(file: Path): BatchConfig = + using(Files.newInputStream(file))(in => fromConfig(JsonMethods.parse(in))) def fromConfig(config: JValue): BatchConfig = { implicit val formats: Formats = DefaultFormats diff --git a/hail/src/main/scala/is/hail/services/requests.scala b/hail/src/main/scala/is/hail/services/requests.scala index ed269b606e3..07ac0806dff 100644 --- a/hail/src/main/scala/is/hail/services/requests.scala +++ b/hail/src/main/scala/is/hail/services/requests.scala @@ -64,7 +64,7 @@ object requests { using(httpClient.execute(req)) { resp => val statusCode = resp.getStatusLine.getStatusCode log.info(s"request ${req.getMethod} ${req.getURI} response $statusCode") - val message = Option(resp.getEntity).map(EntityUtils.toString) + val message = Option(resp.getEntity).map(EntityUtils.toString).filter(_.nonEmpty) if (statusCode < 200 || statusCode >= 300) { throw new ClientResponseException(statusCode, message.orNull) } diff --git a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala index 57b17a1c59c..c6aaa709f20 100644 --- a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala +++ b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala @@ -29,31 +29,16 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV using(ServiceBackend(batchClient, rpcConfig)) { backend => val contexts = Array.tabulate(1)(_.toString.getBytes) 
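// one serialized context per partition; the stubbed BatchClient below checks
// that exactly one job is submitted per context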
- // verify that the service backend - // - creates the batch with the correct billing project, and + // verify that // - the number of jobs matches the number of partitions, and // - each job is created in the specified region, and // - each job's resource configuration matches the rpc config - val batchId = Random.nextInt() - - when(batchClient.newBatch(any[BatchRequest])) thenAnswer { - (batchRequest: BatchRequest) => - batchRequest.billing_project shouldEqual rpcConfig.billing_project - batchRequest.n_jobs shouldBe 0 - batchRequest.attributes.get("name").value shouldBe backend.name - batchId - } - when(batchClient.newJobGroup( - any[Int], - any[String], - any[JobGroupRequest], - any[IndexedSeq[JobRequest]], - )) thenAnswer { - (id: Int, _: String, jobGroup: JobGroupRequest, jobs: IndexedSeq[JobRequest]) => - id shouldBe batchId - jobGroup.job_group_id shouldBe 1 - jobGroup.absolute_parent_id shouldBe 0 + when(batchClient.newJobGroup(any[JobGroupRequest])) thenAnswer { + jobGroup: JobGroupRequest => + jobGroup.batch_id shouldBe backend.batchConfig.batchId + jobGroup.absolute_parent_id shouldBe backend.batchConfig.jobGroupId + val jobs = jobGroup.jobs jobs.length shouldEqual contexts.length jobs.foreach { payload => payload.regions.value shouldBe rpcConfig.regions @@ -65,15 +50,15 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV ) } - (37, 1) + backend.batchConfig.jobGroupId + 1 } // the service backend expects that each job write its output to a well-known // location when it finishes. when(batchClient.waitForJobGroup(any[Int], any[Int])) thenAnswer { (id: Int, jobGroupId: Int) => - id shouldEqual batchId - jobGroupId shouldEqual 1 + id shouldEqual backend.batchConfig.batchId + jobGroupId shouldEqual backend.batchConfig.jobGroupId + 1 val resultsDir = Path(backend.serviceBackendContext.remoteTmpDir) / @@ -106,7 +91,7 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV failure.foreach(throw _) batchClient.newBatch(any) wasCalled once - batchClient.newJobGroup(any, any, any, any) wasCalled once + batchClient.newJobGroup(any) wasCalled once } } @@ -114,12 +99,11 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV val flags = HailFeatureFlags.fromEnv() val fs = RouterFS.buildRoutes(CloudStorageFSConfig()) new ServiceBackend( - jarLocation = "us-docker.pkg.dev/hail-vdc/hail/hailgenetics/hail@sha256:fake", + jarSpec = GitRevision("123"), name = "name", theHailClassLoader = new HailClassLoader(getClass.getClassLoader), batchClient = client, - curBatchId = None, - curJobGroupId = None, + batchConfig = BatchConfig(batchId = Random.nextInt(), jobGroupId = Random.nextInt()), flags = flags, tmpdir = rpcConfig.tmp_dir, fs = fs, diff --git a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala index 4fe1ba353bf..759f3f8232c 100644 --- a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala +++ b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala @@ -1,41 +1,100 @@ package is.hail.services +import is.hail.backend.service.Main import is.hail.utils._ +import scala.sys.process._ + +import java.lang.reflect.Method import java.nio.file.Path import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test -import sourcecode.FullName +import org.testng.annotations.{AfterClass, BeforeClass, BeforeMethod, Test} class BatchClientSuite extends TestNGSuite { - @Test def testBasic(): Unit = - 
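// (superseded below) the old single smoke test stood up a whole batch per
// run; the rewritten suite shares one batch via @BeforeClass and creates a
// parent job group per test via @BeforeMethod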
using(BatchClient(DeployConfig.get(), Path.of("/tmp/test-gsa-key/key.json"))) { client => - val jobGroup = client.run( - BatchRequest( - billing_project = "test", - n_jobs = 0, + + private[this] var client: BatchClient = _ + private[this] var batchId: Int = _ + private[this] var parentJobGroupId: Int = _ + + @BeforeClass + def createClientAndBatch(): Unit = { + client = BatchClient(DeployConfig.get(), Path.of("/tmp/test-gsa-key/key.json")) + batchId = client.newBatch( + BatchRequest( + billing_project = "test", + n_jobs = 0, + token = tokenUrlSafe, + attributes = Map("name" -> s"${getClass.getName}"), + ) + ) + } + + @BeforeMethod + def createEmptyParentJobGroup(m: Method): Unit = { + parentJobGroupId = client.newJobGroup( + req = JobGroupRequest( + batch_id = batchId, + absolute_parent_id = 0, + token = tokenUrlSafe, + attributes = Map("name" -> m.getName), + jobs = FastSeq(), + ) + ) + } + + @AfterClass + def closeClient(): Unit = + client.close() + + @Test + def testNewJobGroup(): Unit = + // The query driver submits a job group per stage with one job per partition + for (i <- 1 to 2) { + val jobGroupId = client.newJobGroup( + req = JobGroupRequest( + batch_id = batchId, + absolute_parent_id = parentJobGroupId, token = tokenUrlSafe, - attributes = Map("name" -> s"Test ${implicitly[FullName].value}"), - ), - JobGroupRequest( - job_group_id = 1, - absolute_parent_id = 0, - ), - FastSeq( + attributes = Map("name" -> s"JobGroup$i"), + jobs = (1 to i).map { k => + JobRequest( + always_run = false, + process = BashJob( + image = "ubuntu:22.04", + command = Array("/bin/bash", "-c", s"echo 'job $k'"), + ), + ) + }, + ) + ) + + val result = client.getJobGroup(batchId, jobGroupId) + assert(result.n_jobs == i) + } + + @Test + def testJvmJob(): Unit = { + val jobGroupId = client.newJobGroup( + req = JobGroupRequest( + batch_id = batchId, + absolute_parent_id = parentJobGroupId, + token = tokenUrlSafe, + attributes = Map("name" -> "TableStage"), + jobs = FastSeq( JobRequest( - job_id = 1, always_run = false, - in_update_job_group_id = 0, - in_update_parent_ids = Array(), - process = BashJob( - image = "ubuntu:22.04", - command = Array("/bin/bash", "-c", "echo 'hello, hail!'"), + process = JvmJob( + command = Array(Main.WORKER, "", "", ""), + spec = GitRevision("git rev-parse main".!!.strip()), + profile = false, ), ) ), ) + ) - assert(jobGroup.state == JobGroupStates.Success) - } + val result = client.getJobGroup(batchId, jobGroupId) + assert(result.n_jobs == 1) + } } From eb6ce69a9614d2bcfec07ef811835228142ebcdc Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Wed, 23 Oct 2024 17:22:29 -0400 Subject: [PATCH 08/20] fix scalatests --- hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala | 2 +- hail/src/test/scala/is/hail/services/BatchClientSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala index c6aaa709f20..c66e1b0fcbd 100644 --- a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala +++ b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala @@ -90,8 +90,8 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV failure.foreach(throw _) - batchClient.newBatch(any) wasCalled once batchClient.newJobGroup(any) wasCalled once + batchClient.waitForJobGroup(any, any) wasCalled once } } diff --git a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala 
b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala index 759f3f8232c..38e95fb651f 100644 --- a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala +++ b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala @@ -19,7 +19,7 @@ class BatchClientSuite extends TestNGSuite { @BeforeClass def createClientAndBatch(): Unit = { - client = BatchClient(DeployConfig.get(), Path.of("/tmp/test-gsa-key/key.json")) + client = BatchClient(DeployConfig.get(), Path.of("/test-gsa-key/key.json")) batchId = client.newBatch( BatchRequest( billing_project = "test", From 5480673ad4688f8cb365b08cb5deba1e491acbb4 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Wed, 23 Oct 2024 18:01:22 -0400 Subject: [PATCH 09/20] job ids are 1-based --- hail/src/main/scala/is/hail/backend/service/Main.scala | 2 ++ hail/src/main/scala/is/hail/services/BatchClient.scala | 2 +- hail/src/test/scala/is/hail/services/BatchClientSuite.scala | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/hail/src/main/scala/is/hail/backend/service/Main.scala b/hail/src/main/scala/is/hail/backend/service/Main.scala index 698f5ffa23c..e6c03bd6596 100644 --- a/hail/src/main/scala/is/hail/backend/service/Main.scala +++ b/hail/src/main/scala/is/hail/backend/service/Main.scala @@ -3,11 +3,13 @@ package is.hail.backend.service object Main { val WORKER = "worker" val DRIVER = "driver" + val TEST = "test" def main(argv: Array[String]): Unit = argv(3) match { case WORKER => Worker.main(argv) case DRIVER => ServiceBackendAPI.main(argv) + case TEST => () case kind => throw new RuntimeException(s"unknown kind: $kind") } } diff --git a/hail/src/main/scala/is/hail/services/BatchClient.scala b/hail/src/main/scala/is/hail/services/BatchClient.scala index a79dc7fcead..6df01fd0f51 100644 --- a/hail/src/main/scala/is/hail/services/BatchClient.scala +++ b/hail/src/main/scala/is/hail/services/BatchClient.scala @@ -238,7 +238,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab .asInstanceOf[JObject] .merge( JObject( - "job_id" -> JInt(jobIdx), + "job_id" -> JInt(jobIdx + 1), "in_update_job_group_id" -> JInt(1), ) ) diff --git a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala index 38e95fb651f..b7d0692b567 100644 --- a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala +++ b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala @@ -85,7 +85,7 @@ class BatchClientSuite extends TestNGSuite { JobRequest( always_run = false, process = JvmJob( - command = Array(Main.WORKER, "", "", ""), + command = Array(Main.TEST), spec = GitRevision("git rev-parse main".!!.strip()), profile = false, ), From b1e099401bae78b708e3200277f453f4448a0adf Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Wed, 23 Oct 2024 18:49:15 -0400 Subject: [PATCH 10/20] depend on uploaded qob jar --- build.yaml | 1 + hail/src/test/scala/is/hail/services/BatchClientSuite.scala | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/build.yaml b/build.yaml index b3b3eb68d1c..3d6645220cd 100644 --- a/build.yaml +++ b/build.yaml @@ -3720,6 +3720,7 @@ steps: - hail_run_image - build_debug_hail_test_jar - build_hail_test_artifacts + - upload_query_jar - deploy_batch - kind: runImage name: start_hail_benchmark diff --git a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala index b7d0692b567..26671c4b3ff 100644 --- 
a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala +++ b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala @@ -1,10 +1,9 @@ package is.hail.services +import is.hail.HAIL_REVISION import is.hail.backend.service.Main import is.hail.utils._ -import scala.sys.process._ - import java.lang.reflect.Method import java.nio.file.Path @@ -86,7 +85,7 @@ class BatchClientSuite extends TestNGSuite { always_run = false, process = JvmJob( command = Array(Main.TEST), - spec = GitRevision("git rev-parse main".!!.strip()), + spec = GitRevision(HAIL_REVISION), profile = false, ), ) From 337e52e52d48ce61b672755da9d64466042fe0c2 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Wed, 23 Oct 2024 19:08:09 -0400 Subject: [PATCH 11/20] jar location is a url, not a json spec --- .../hail/backend/service/ServiceBackend.scala | 12 +++---- .../scala/is/hail/services/BatchClient.scala | 34 ++++++++----------- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 2506dfcab53..979c1428294 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -53,7 +53,7 @@ object ServiceBackend { private val log = Logger.getLogger(getClass.getName()) def apply( - jarSpec: JarSpec, + jarLocation: String, name: String, theHailClassLoader: HailClassLoader, batchClient: BatchClient, @@ -86,7 +86,7 @@ object ServiceBackend { ) val backend = new ServiceBackend( - jarSpec, + JarUrl(jarLocation), name, theHailClassLoader, batchClient, @@ -397,16 +397,14 @@ object ServiceBackendAPI { val scratchDir = argv(0) // val logFile = argv(1) - val jarSpecStr = argv(2) + val jarLocation = argv(2) val kind = argv(3) assert(kind == Main.DRIVER) val name = argv(4) val inputURL = argv(5) val outputURL = argv(6) - implicit val formats: Formats = DefaultFormats + JarSpecFormats - - val jarGitRevision = JsonMethods.parse(jarSpecStr).extract[JarSpec] + implicit val formats: Formats = DefaultFormats val fs = RouterFS.buildRoutes( CloudStorageFSConfig.fromFlagsAndEnv( @@ -430,7 +428,7 @@ object ServiceBackendAPI { // FIXME: when can the classloader be shared? (optimizer benefits!) 
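// the jar location is passed through as the raw argv string again;
// ServiceBackend.apply is now responsible for wrapping it in JarUrl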
val backend = ServiceBackend( - jarGitRevision, + jarLocation, name, new HailClassLoader(getClass().getClassLoader()), batchClient, diff --git a/hail/src/main/scala/is/hail/services/BatchClient.scala b/hail/src/main/scala/is/hail/services/BatchClient.scala index 6df01fd0f51..571c0327942 100644 --- a/hail/src/main/scala/is/hail/services/BatchClient.scala +++ b/hail/src/main/scala/is/hail/services/BatchClient.scala @@ -50,25 +50,6 @@ sealed trait JarSpec case class GitRevision(sha: String) extends JarSpec case class JarUrl(url: String) extends JarSpec -object JarSpecFormats extends CustomSerializer[JarSpec](implicit fmts => - ( - { - case obj: JObject => - val value = (obj \ "value").extract[String] - (obj \ "type").extract[String] match { - case "jar_url" => JarUrl(value) - case "git_revision" => GitRevision(value) - } - }, - { - case JarUrl(url) => - JObject("type" -> JString("jar_url"), "value" -> JString(url)) - case GitRevision(sha) => - JObject("type" -> JString("git_revision"), "value" -> JString(sha)) - }, - ) - ) - case class JobResources( preemptible: Boolean, cpu: Option[String], @@ -138,7 +119,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab JobProcessRequestSerializer + JobGroupStateDeserializer + JobGroupResponseDeserializer + - JarSpecFormats + JarSpecSerializer def newBatch(createRequest: BatchRequest): Int = { val response = req.post("/api/v1alpha/batches/create", Extraction.decompose(createRequest)) @@ -331,4 +312,17 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab PartialFunction.empty, ) ) + + private[this] object JarSpecSerializer + extends CustomSerializer[JarSpec](_ => + ( + PartialFunction.empty, + { + case JarUrl(url) => + JObject("type" -> JString("jar_url"), "value" -> JString(url)) + case GitRevision(sha) => + JObject("type" -> JString("git_revision"), "value" -> JString(sha)) + }, + ) + ) } From 64db2276823ccd76a786ad349714091046c519e4 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Wed, 23 Oct 2024 19:34:32 -0400 Subject: [PATCH 12/20] log request after the fact --- hail/src/main/scala/is/hail/services/requests.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/hail/src/main/scala/is/hail/services/requests.scala b/hail/src/main/scala/is/hail/services/requests.scala index 07ac0806dff..b6ec90a08b3 100644 --- a/hail/src/main/scala/is/hail/services/requests.scala +++ b/hail/src/main/scala/is/hail/services/requests.scala @@ -57,17 +57,18 @@ object requests { } def request(req: HttpUriRequest, body: Option[HttpEntity] = None): JValue = { - log.info(s"request ${req.getMethod} ${req.getURI}") req.addHeader("Authorization", s"Bearer ${cred.accessToken}") body.foreach(entity => req.asInstanceOf[HttpEntityEnclosingRequest].setEntity(entity)) retryTransientErrors { using(httpClient.execute(req)) { resp => val statusCode = resp.getStatusLine.getStatusCode - log.info(s"request ${req.getMethod} ${req.getURI} response $statusCode") val message = Option(resp.getEntity).map(EntityUtils.toString).filter(_.nonEmpty) if (statusCode < 200 || statusCode >= 300) { + log.warn(s"$statusCode ${req.getMethod} ${req.getURI}\n${message.orNull}") throw new ClientResponseException(statusCode, message.orNull) } + + log.info(s"$statusCode ${req.getMethod} ${req.getURI}") message.map(JsonMethods.parse(_)).getOrElse(JNothing) } } From 333068d6ea5e6ac51d7209176d87396fca7f18e4 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Wed, 23 Oct 2024 21:28:02 -0400 Subject: [PATCH 13/20] use rstrip --- 
hail/python/hail/context.py | 2 +- hail/python/hailtop/config/user_config.py | 2 +- hail/python/hailtop/hailctl/batch/submit.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/hail/python/hail/context.py b/hail/python/hail/context.py index 5258f27fbc1..1055c802b7b 100644 --- a/hail/python/hail/context.py +++ b/hail/python/hail/context.py @@ -573,7 +573,7 @@ async def init_batch( log = _get_log(log) if tmpdir is None: - tmpdir = backend.remote_tmpdir + 'tmp/hail/' + secret_alnum_string() + tmpdir = backend.remote_tmpdir + '/tmp/hail/' + secret_alnum_string() local_tmpdir = _get_local_tmpdir(local_tmpdir) HailContext.create(log, quiet, append, tmpdir, local_tmpdir, default_reference, global_seed, backend) diff --git a/hail/python/hailtop/config/user_config.py b/hail/python/hailtop/config/user_config.py index a6d003b4c12..55114bba48b 100644 --- a/hail/python/hailtop/config/user_config.py +++ b/hail/python/hailtop/config/user_config.py @@ -145,4 +145,4 @@ def get_remote_tmpdir( f'remote_tmpdir must be a storage uri path like gs://bucket/folder. Received: {remote_tmpdir}. Possible schemes include gs for GCP and https for Azure' ) - return remote_tmpdir[:-1] if remote_tmpdir[-1] == '/' else remote_tmpdir + return remote_tmpdir.rstrip('/') diff --git a/hail/python/hailtop/hailctl/batch/submit.py b/hail/python/hailtop/hailctl/batch/submit.py index 21547f8e0b5..196b844ab23 100644 --- a/hail/python/hailtop/hailctl/batch/submit.py +++ b/hail/python/hailtop/hailctl/batch/submit.py @@ -29,7 +29,6 @@ async def submit(name, image_name, files, output, script, arguments): quiet = output != 'text' remote_tmpdir = get_remote_tmpdir('hailctl batch submit') - remote_tmpdir = remote_tmpdir.rstrip('/') tmpdir_path_prefix = secret_alnum_string() From 539f063b97207acaa309c351cb5e8bde506d2e18 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Thu, 24 Oct 2024 00:19:08 -0400 Subject: [PATCH 14/20] preserve old, broken behaviour --- .../hail/backend/service/ServiceBackend.scala | 138 +++++++++--------- 1 file changed, 73 insertions(+), 65 deletions(-) diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 979c1428294..663eb8371b1 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -1,5 +1,4 @@ package is.hail.backend.service - import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ @@ -156,52 +155,24 @@ class ServiceBackend( new String(bytes, StandardCharsets.UTF_8) } - private[this] def submitAndWaitForBatch( - _backendContext: BackendContext, - fs: FS, + private[this] def submitJobGroupAndWait( + backendContext: ServiceBackendContext, collection: IndexedSeq[Array[Byte]], + token: String, + root: String, stageIdentifier: String, - f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte], - ): (String, String, Int) = { - val backendContext = _backendContext.asInstanceOf[ServiceBackendContext] - val n = collection.length - val token = tokenUrlSafe - val root = s"${backendContext.remoteTmpDir}/parallelizeAndComputeWithIndex/$token" - - log.info(s"parallelizeAndComputeWithIndex: $token: nPartitions $n") - log.info(s"parallelizeAndComputeWithIndex: $token: writing f and contexts") - - val uploadFunction = executor.submit[Unit](() => - retryTransientErrors { - fs.writePDOS(s"$root/f") { fos => - using(new 
ObjectOutputStream(fos))(oos => oos.writeObject(f)) - } - } - ) - - val uploadContexts = executor.submit[Unit](() => - retryTransientErrors { - fs.writePDOS(s"$root/contexts") { os => - var o = 12L * n // 12L = sizeof(Long) + sizeof(Int) - collection.foreach { context => - val len = context.length - os.writeLong(o) - os.writeInt(len) - o += len - } - collection.foreach(context => os.write(context)) - } - } - ) + ): JobGroupResponse = { + val defaultProcess = + JvmJob( + command = null, + spec = jarSpec, + profile = flags.get("profile") != null, + ) - val jobs = collection.indices.map { i => + val defaultJob = JobRequest( always_run = false, - process = JvmJob( - command = Array(Main.WORKER, root, s"$i", s"$n"), - spec = jarSpec, - profile = flags.get("profile") != null, - ), + process = null, resources = Some( JobResources( preemptible = true, @@ -212,34 +183,33 @@ class ServiceBackend( ), regions = Some(backendContext.regions).filter(_.nonEmpty), cloudfuse = Some(backendContext.cloudfuseConfig).filter(_.nonEmpty), - attributes = Map("name" -> s"${name}_stage${stageCount}_${stageIdentifier}_job$i"), ) - } - uploadFunction.get() - uploadContexts.get() - - log.info(s"parallelizeAndComputeWithIndex: $token: running job") + val jobs = + collection.indices.map { i => + defaultJob.copy( + attributes = Map("name" -> s"${name}_stage${stageCount}_${stageIdentifier}_job$i"), + process = defaultProcess.copy( + command = Array(Main.WORKER, root, s"$i", s"${collection.length}") + ), + ) + } - val jobGroupId = batchClient.newJobGroup( - JobGroupRequest( - batch_id = batchConfig.batchId, - absolute_parent_id = batchConfig.jobGroupId, - token = tokenUrlSafe, - attributes = Map("name" -> stageIdentifier), - jobs = jobs, + val jobGroupId = + batchClient.newJobGroup( + JobGroupRequest( + batch_id = batchConfig.batchId, + absolute_parent_id = batchConfig.jobGroupId, + token = token, + attributes = Map("name" -> stageIdentifier), + jobs = jobs, + ) ) - ) - - Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms - val response = batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId) stageCount += 1 - if (response.state == Failure) { - throw new HailBatchFailure(s"JobGroup $jobGroupId for batch ${batchConfig.batchId} failed") - } - (token, root, n) + Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms + batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId) } private[this] def readResult(root: String, i: Int): Array[Byte] = { @@ -267,13 +237,45 @@ class ServiceBackend( f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = { + val backendContext = _backendContext.asInstanceOf[ServiceBackendContext] + + val token = tokenUrlSafe + val root = s"${backendContext.remoteTmpDir}/parallelizeAndComputeWithIndex/$token" + + val uploadFunction = executor.submit[Unit](() => + retryTransientErrors { + fs.writePDOS(s"$root/f") { fos => + using(new ObjectOutputStream(fos))(oos => oos.writeObject(f)) + log.info(s"parallelizeAndComputeWithIndex: $token: uploaded f") + } + } + ) + val (partIdxs, parts) = partitions .map(ps => (ps, ps.map(contexts))) .getOrElse((contexts.indices, contexts)) - val (token, root, _) = - submitAndWaitForBatch(_backendContext, fs, parts, stageIdentifier, f) + val uploadContexts = executor.submit[Unit](() => + retryTransientErrors { + fs.writePDOS(s"$root/contexts") { os => + var o = 12L * parts.length // 12L = sizeof(Long) + sizeof(Int) + 
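// contexts file layout: n (offset: Long, length: Int) pairs followed by the
// concatenated context payloads; offsets are absolute within the file, hence
// starting o past the 12-byte-per-entry header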
parts.foreach { context =>
+            val len = context.length
+            os.writeLong(o)
+            os.writeInt(len)
+            o += len
+          }
+          parts.foreach(os.write)
+          log.info(s"parallelizeAndComputeWithIndex: $token: wrote ${parts.length} contexts")
+        }
+      }
+    )
+
+    uploadFunction.get()
+    uploadContexts.get()
+
+    val jobGroup = submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier)

     log.info(s"parallelizeAndComputeWithIndex: $token: reading results")
     val startTime = System.nanoTime()
@@ -285,6 +287,12 @@ class ServiceBackend(

     error.foreach(throw _)

+    if (jobGroup.state == Failure) {
+      throw new HailBatchFailure(
+        s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} failed with an unknown error"
+      )
+    }
+
     val resultsReadingSeconds = (System.nanoTime() - startTime) / 1000000000.0
     val rate = results.length / resultsReadingSeconds
     val byterate = results.map(_._1.length).sum / resultsReadingSeconds / 1024 / 1024

From 69a5c38fcab5c914411d35aff1a4986d043b0438 Mon Sep 17 00:00:00 2001
From: Edmund Higham
Date: Thu, 24 Oct 2024 00:48:25 -0400
Subject: [PATCH 15/20] scala fucking format at 1am

---
 hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
index 663eb8371b1..b0d06b8b0dc 100644
--- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
+++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
@@ -1,4 +1,5 @@
 package is.hail.backend.service
+
 import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags}
 import is.hail.annotations._
 import is.hail.asm4s._

From 2354bb3f0c657a611576069d1cf7b0f10e15d796 Mon Sep 17 00:00:00 2001
From: Edmund Higham
Date: Fri, 25 Oct 2024 11:28:01 -0400
Subject: [PATCH 16/20] refresh azure tokens if they expire in 5 mins?
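Note: the predicate below still compares in the same direction after this
change: `isExpired` evaluates to true while the cached token remains valid for
more than five minutes and false once it is about to lapse, so tokens are
refreshed while still fresh and served stale near expiry. A minimal sketch of
the check presumably intended (hypothetical, not what this patch applies; it
assumes `AccessToken.getExpiresAt` returns the token's expiry time):

    private[this] def isExpired: Boolean =
      token == null || token.getExpiresAt.isBefore(OffsetDateTime.now.plusMinutes(5))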
---
 hail/src/main/scala/is/hail/services/oauth2.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hail/src/main/scala/is/hail/services/oauth2.scala b/hail/src/main/scala/is/hail/services/oauth2.scala
index 7bce6ac70d2..5063a51bb8e 100644
--- a/hail/src/main/scala/is/hail/services/oauth2.scala
+++ b/hail/src/main/scala/is/hail/services/oauth2.scala
@@ -88,7 +88,7 @@ object oauth2 {
     }

     private[this] def isExpired: Boolean =
-      token == null || OffsetDateTime.now.plusHours(1).isBefore(token.getExpiresAt)
+      token == null || OffsetDateTime.now.plusMinutes(5).isBefore(token.getExpiresAt)
   }

   object AzureCloudCredentials {

From 5642d5f5f1a3743b34e9d4a3adaa22bd10266710 Mon Sep 17 00:00:00 2001
From: Edmund Higham
Date: Tue, 29 Oct 2024 13:44:53 -0400
Subject: [PATCH 17/20] feedback from @chrisvittal

---
 hail/python/hail/context.py                       | 2 +-
 hail/src/main/scala/is/hail/backend/Backend.scala | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/hail/python/hail/context.py b/hail/python/hail/context.py
index 1055c802b7b..77493ecd083 100644
--- a/hail/python/hail/context.py
+++ b/hail/python/hail/context.py
@@ -573,7 +573,7 @@ async def init_batch(
     log = _get_log(log)

     if tmpdir is None:
-        tmpdir = backend.remote_tmpdir + '/tmp/hail/' + secret_alnum_string()
+        tmpdir = os.path.join(backend.remote_tmpdir, 'tmp/hail', secret_alnum_string())
     local_tmpdir = _get_local_tmpdir(local_tmpdir)

     HailContext.create(log, quiet, append, tmpdir, local_tmpdir, default_reference, global_seed, backend)

diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala
index e0444e371e2..329ee1a3e38 100644
--- a/hail/src/main/scala/is/hail/backend/Backend.scala
+++ b/hail/src/main/scala/is/hail/backend/Backend.scala
@@ -123,8 +123,6 @@ abstract class Backend extends Closeable {
     f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]
   ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)])

-  def close(): Unit
-
   def asSpark(op: String): SparkBackend =
     fatal(s"${getClass.getSimpleName}: $op requires SparkBackend")

From 6e36a7afd5f62794c9484ccdcebb8c8c437767f2 Mon Sep 17 00:00:00 2001
From: Edmund Higham
Date: Tue, 29 Oct 2024 17:27:09 -0400
Subject: [PATCH 18/20] don't throw in parallelizeAndComputeWithIndex

---
 .../hail/backend/service/ServiceBackend.scala | 23 +++++++++++++------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
index b0d06b8b0dc..14970d452a1 100644
--- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
+++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
@@ -15,7 +15,7 @@ import is.hail.expr.ir.lowering._
 import is.hail.io.fs._
 import is.hail.linalg.BlockMatrix
 import is.hail.services.{BatchClient, JobGroupRequest, _}
-import is.hail.services.JobGroupStates.Failure
+import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success}
 import is.hail.types._
 import is.hail.types.physical._
 import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType
@@ -280,18 +280,27 @@ class ServiceBackend(
     log.info(s"parallelizeAndComputeWithIndex: $token: reading results")
     val startTime = System.nanoTime()

-    val r @ (error, results) = runAllKeepFirstError(new CancellingExecutorService(executor)) {
+    var r @ (err, results) = runAllKeepFirstError(new CancellingExecutorService(executor)) {
       (partIdxs, parts.indices).zipped.map { (partIdx, jobIndex)
=> (() => readResult(root, jobIndex), partIdx) } } - error.foreach(throw _) + if (jobGroup.state != Success && err.isEmpty) { + assert(jobGroup.state != Running) + val error = + jobGroup.state match { + case Failure => + new HailBatchFailure( + s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} failed with an unknown error" + ) + case Cancelled => + new CancellationException( + s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} was cancelled" + ) + } - if (jobGroup.state == Failure) { - throw new HailBatchFailure( - s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} failed with an unknown error" - ) + r = (Some(error), results) } val resultsReadingSeconds = (System.nanoTime() - startTime) / 1000000000.0 From 06c57e5bc3fe25070a0d00efd19778d8870c392a Mon Sep 17 00:00:00 2001 From: grohli <22306963+grohli@users.noreply.github.com> Date: Fri, 25 Oct 2024 16:26:27 -0400 Subject: [PATCH 19/20] [qob] cancel stage if any partitions fail. --- .../hail/backend/service/ServiceBackend.scala | 12 +++++-- .../scala/is/hail/services/BatchClient.scala | 14 +++++--- .../is/hail/services/BatchClientSuite.scala | 33 +++++++++++++++++++ 3 files changed, 52 insertions(+), 7 deletions(-) diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 14970d452a1..d9e0b40777f 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -1,6 +1,6 @@ package is.hail.backend.service -import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags} +import is.hail.{HailContext, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ @@ -202,6 +202,7 @@ class ServiceBackend( batch_id = batchConfig.batchId, absolute_parent_id = batchConfig.jobGroupId, token = token, + cancel_after_n_failures = Some(1), attributes = Map("name" -> stageIdentifier), jobs = jobs, ) @@ -280,12 +281,17 @@ class ServiceBackend( log.info(s"parallelizeAndComputeWithIndex: $token: reading results") val startTime = System.nanoTime() - var r @ (err, results) = runAllKeepFirstError(new CancellingExecutorService(executor)) { + var r @ (err, results) = runAll[Option, Array[Byte]](executor) { + /* A missing file means the job was cancelled because another job failed. Assumes that if any + * job was cancelled, then at least one job failed. We want to ignore the missing file + * exceptions and return one of the actual failure exceptions. 
*/ + case (opt, _: FileNotFoundException) => opt + case (opt, e) => opt.orElse(Some(e)) + }(None) { (partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) => (() => readResult(root, jobIndex), partIdx) } } - if (jobGroup.state != Success && err.isEmpty) { assert(jobGroup.state != Running) val error = diff --git a/hail/src/main/scala/is/hail/services/BatchClient.scala b/hail/src/main/scala/is/hail/services/BatchClient.scala index 571c0327942..8b956b7dc6e 100644 --- a/hail/src/main/scala/is/hail/services/BatchClient.scala +++ b/hail/src/main/scala/is/hail/services/BatchClient.scala @@ -14,7 +14,9 @@ import java.nio.file.Path import org.apache.http.entity.ByteArrayEntity import org.apache.http.entity.ContentType.APPLICATION_JSON -import org.json4s.{CustomSerializer, DefaultFormats, Extraction, Formats, JInt, JObject, JString} +import org.json4s.{ + CustomSerializer, DefaultFormats, Extraction, Formats, JInt, JNull, JObject, JString, +} import org.json4s.JsonAST.{JArray, JBool} import org.json4s.jackson.JsonMethods @@ -29,6 +31,7 @@ case class JobGroupRequest( batch_id: Int, absolute_parent_id: Int, token: String, + cancel_after_n_failures: Option[Int] = None, attributes: Map[String, String] = Map.empty, jobs: IndexedSeq[JobRequest] = FastSeq(), ) @@ -52,9 +55,9 @@ case class JarUrl(url: String) extends JarSpec case class JobResources( preemptible: Boolean, - cpu: Option[String], - memory: Option[String], - storage: Option[String], + cpu: Option[String] = None, + memory: Option[String] = None, + storage: Option[String] = None, ) case class CloudfuseConfig( @@ -252,6 +255,9 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab JObject( "job_group_id" -> JInt(1), // job group id relative to the update "absolute_parent_id" -> JInt(jobGroup.absolute_parent_id), + "cancel_after_n_failures" -> jobGroup.cancel_after_n_failures.map(JInt(_)).getOrElse( + JNull + ), "attributes" -> Extraction.decompose(jobGroup.attributes), ) )), diff --git a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala index 26671c4b3ff..04b69e0997c 100644 --- a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala +++ b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala @@ -2,6 +2,7 @@ package is.hail.services import is.hail.HAIL_REVISION import is.hail.backend.service.Main +import is.hail.services.JobGroupStates.Failure import is.hail.utils._ import java.lang.reflect.Method @@ -46,6 +47,38 @@ class BatchClientSuite extends TestNGSuite { def closeClient(): Unit = client.close() + @Test + def testCancelAfterNFailures(): Unit = { + val jobGroupId = client.newJobGroup( + req = JobGroupRequest( + batch_id = batchId, + absolute_parent_id = parentJobGroupId, + cancel_after_n_failures = Some(1), + token = tokenUrlSafe, + jobs = FastSeq( + JobRequest( + always_run = false, + process = BashJob( + image = "ubuntu:22.04", + command = Array("/bin/bash", "-c", "sleep 1d"), + ), + resources = Some(JobResources(preemptible = true)), + ), + JobRequest( + always_run = false, + process = BashJob( + image = "ubuntu:22.04", + command = Array("/bin/bash", "-c", "exit 1"), + ), + ), + ), + ) + ) + val result = client.waitForJobGroup(batchId, jobGroupId) + assert(result.state == Failure) + assert(result.n_cancelled == 1) + } + @Test def testNewJobGroup(): Unit = // The query driver submits a job group per stage with one job per partition From daa4a2ec4e944976efb4f1b699269746c654c8c5 Mon Sep 17 00:00:00 2001 From: grohli 
<22306963+grohli@users.noreply.github.com> Date: Thu, 31 Oct 2024 11:29:25 -0400 Subject: [PATCH 20/20] Change duration of job in test --- hail/src/test/scala/is/hail/services/BatchClientSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala index 04b69e0997c..079870efce2 100644 --- a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala +++ b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala @@ -60,7 +60,7 @@ class BatchClientSuite extends TestNGSuite { always_run = false, process = BashJob( image = "ubuntu:22.04", - command = Array("/bin/bash", "-c", "sleep 1d"), + command = Array("/bin/bash", "-c", "sleep 5m"), ), resources = Some(JobResources(preemptible = true)), ),