wire up spark, local and py4j backends
ehigham committed Oct 1, 2024
1 parent 28e88c8 commit 6baee10
Showing 8 changed files with 59 additions and 56 deletions.
15 changes: 3 additions & 12 deletions hail/python/hail/backend/local_backend.py
@@ -81,23 +81,14 @@ def __init__(
)
jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations)

super(LocalBackend, self).__init__(self._gateway.jvm, jbackend, jhc)
super().__init__(self._gateway.jvm, jbackend, jhc, tmpdir, tmpdir)
self.gcs_requester_pays_configuration = gcs_requester_pays_configuration
self._fs = self._exit_stack.enter_context(
RouterFS(gcs_kwargs={'gcs_requester_pays_configuration': gcs_requester_pays_configuration})
)

self._logger = None

flags = {}
if gcs_requester_pays_configuration is not None:
if isinstance(gcs_requester_pays_configuration, str):
flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration
else:
assert isinstance(gcs_requester_pays_configuration, tuple)
flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration[0]
flags['gcs_requester_pays_buckets'] = ','.join(gcs_requester_pays_configuration[1])

self._initialize_flags(flags)
self._initialize_flags({})

def validate_file(self, uri: str) -> None:
async_to_blocking(validate_file(uri, self._fs.afs))
29 changes: 27 additions & 2 deletions hail/python/hail/backend/py4j_backend.py
@@ -13,8 +13,10 @@

import hail
from hail.expr import construct_expr
from hail.fs.hadoop_fs import HadoopFS
from hail.ir import JavaIR
from hail.utils.java import Env, FatalError, scala_package_object
from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration

from ..hail_logging import Logger
from .backend import ActionTag, Backend, fatal_error_from_java_error_triplet
@@ -193,11 +195,17 @@ def decode_bytearray(encoded):
self._utils_package_object = scala_package_object(self._hail_package.utils)
self._jhc = jhc

self._jbackend = self._hail_package.backend.api.P4jBackendApi(jbackend)
self._jbackend = self._hail_package.backend.api.Py4JBackendApi(jbackend)
self._jbackend.pySetLocalTmp(tmpdir)
self._jbackend.pySetRemoteTmp(remote_tmpdir)

self._jhttp_server = self._jbackend.pyHttpServer()
self._backend_server_port: int = self._jbackend.HttpServer.port()
self._backend_server_port: int = self._jhttp_server.port()
self._requests_session = requests.Session()

self._gcs_requester_pays_config = None
self._fs = None

# This has to go after creating the SparkSession. Unclear why.
# Maybe it does its own patch?
install_exception_handler()
@@ -221,6 +229,23 @@ def hail_package(self):
def utils_package_object(self):
return self._utils_package_object

@property
def gcs_requester_pays_configuration(self) -> Optional[GCSRequesterPaysConfiguration]:
return self._gcs_requester_pays_config

@gcs_requester_pays_configuration.setter
def gcs_requester_pays_configuration(self, config: Optional[GCSRequesterPaysConfiguration]):
self._gcs_requester_pays_config = config
project, buckets = (None, None) if config is None else (config, None) if isinstance(config, str) else config
self._jbackend.pySetGcsRequesterPaysConfig(project, buckets)
self._fs = None # stale

@property
def fs(self):
if self._fs is None:
self._fs = HadoopFS(self._utils_package_object, self._jbackend.pyFs())
return self._fs

@property
def logger(self):
if self._logger is None:
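For reference, the new gcs_requester_pays_configuration setter above collapses the configuration — either a bare project string or a (project, buckets) tuple — into the (project, buckets) pair handed to pySetGcsRequesterPaysConfig. A minimal standalone sketch of that normalization (the helper name is hypothetical and not part of this commit):

from typing import List, Optional, Tuple, Union

# Mirrors hailtop's GCSRequesterPaysConfiguration: a project name, or a
# (project, buckets) tuple restricting requester-pays billing to those buckets.
GCSRequesterPaysConfiguration = Union[str, Tuple[str, List[str]]]

def split_requester_pays_config(
    config: Optional[GCSRequesterPaysConfiguration],
) -> Tuple[Optional[str], Optional[List[str]]]:
    # None -> (None, None); "project" -> ("project", None); a tuple passes through.
    if config is None:
        return None, None
    if isinstance(config, str):
        return config, None
    project, buckets = config
    return project, buckets

assert split_requester_pays_config(None) == (None, None)
assert split_requester_pays_config("my-project") == ("my-project", None)
assert split_requester_pays_config(("my-project", ["b1", "b2"])) == ("my-project", ["b1", "b2"])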
29 changes: 6 additions & 23 deletions hail/python/hail/backend/spark_backend.py
@@ -7,11 +7,11 @@
import pyspark.sql

from hail.expr.table_type import ttable
from hail.fs.hadoop_fs import HadoopFS
from hail.ir import BaseIR
from hail.ir.renderer import CSERenderer
from hail.table import Table
from hail.utils import copy_log
from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration
from hailtop.aiotools.router_fs import RouterAsyncFS
from hailtop.aiotools.validators import validate_file
from hailtop.utils import async_to_blocking
@@ -47,12 +47,9 @@ def __init__(
skip_logging_configuration,
optimizer_iterations,
*,
gcs_requester_pays_project: Optional[str] = None,
gcs_requester_pays_buckets: Optional[str] = None,
gcs_requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None,
copy_log_on_error: bool = False,
):
assert gcs_requester_pays_project is not None or gcs_requester_pays_buckets is None

try:
local_jar_info = local_jar_information()
except ValueError:
@@ -120,10 +117,6 @@ def __init__(
append,
skip_logging_configuration,
min_block_size,
tmpdir,
local_tmpdir,
gcs_requester_pays_project,
gcs_requester_pays_buckets,
)
jhc = hail_package.HailContext.getOrCreate(jbackend, branching_factor, optimizer_iterations)
else:
@@ -137,10 +130,6 @@ def __init__(
append,
skip_logging_configuration,
min_block_size,
tmpdir,
local_tmpdir,
gcs_requester_pays_project,
gcs_requester_pays_buckets,
)
jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations)

@@ -149,12 +138,12 @@ def __init__(
self.sc = sc
else:
self.sc = pyspark.SparkContext(gateway=self._gateway, jsc=jvm.JavaSparkContext(self._jsc))
self._jspark_session = jbackend.sparkSession()
self._jspark_session = jbackend.sparkSession().apply()
self._spark_session = pyspark.sql.SparkSession(self.sc, self._jspark_session)

super(SparkBackend, self).__init__(jvm, jbackend, jhc)
super().__init__(jvm, jbackend, jhc, local_tmpdir, tmpdir)
self.gcs_requester_pays_configuration = gcs_requester_pays_config

self._fs = None
self._logger = None

if not quiet:
@@ -167,7 +156,7 @@ def __init__(
self._initialize_flags({})

self._router_async_fs = RouterAsyncFS(
gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_project}
gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_config}
)

self._tmpdir = tmpdir
@@ -181,12 +170,6 @@ def stop(self):
self.sc.stop()
self.sc = None

@property
def fs(self):
if self._fs is None:
self._fs = HadoopFS(self._utils_package_object, self._jbackend.fs())
return self._fs

def from_spark(self, df, key):
result_tuple = self._jbackend.pyFromDF(df._jdf, key)
tir_id, type_json = result_tuple._1(), result_tuple._2()
13 changes: 4 additions & 9 deletions hail/python/hail/context.py
@@ -474,14 +474,10 @@ def init_spark(
optimizer_iterations = get_env_or_default(_optimizer_iterations, 'HAIL_OPTIMIZER_ITERATIONS', 3)

app_name = app_name or 'Hail'
(
gcs_requester_pays_project,
gcs_requester_pays_buckets,
) = convert_gcs_requester_pays_configuration_to_hadoop_conf_style(
get_gcs_requester_pays_configuration(
gcs_requester_pays_configuration=gcs_requester_pays_configuration,
)
gcs_requester_pays_configuration = get_gcs_requester_pays_configuration(
gcs_requester_pays_configuration=gcs_requester_pays_configuration,
)

backend = SparkBackend(
idempotent,
sc,
@@ -498,8 +494,7 @@
local_tmpdir,
skip_logging_configuration,
optimizer_iterations,
gcs_requester_pays_project=gcs_requester_pays_project,
gcs_requester_pays_buckets=gcs_requester_pays_buckets,
gcs_requester_pays_config=gcs_requester_pays_configuration,
copy_log_on_error=copy_log_on_error,
)
if not backend.fs.exists(tmpdir):
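A hedged usage sketch of the simplified init_spark path above — the requester-pays configuration is now passed through unchanged instead of being pre-split into project and bucket arguments. The project and bucket names are made up, and all other init_spark parameters are assumed to keep their defaults:

from hail.context import init_spark

# Either form is accepted: a project string, or a (project, buckets) tuple.
init_spark(gcs_requester_pays_configuration=("my-gcp-project", ["my-requester-pays-bucket"]))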
20 changes: 14 additions & 6 deletions hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala
@@ -6,7 +6,10 @@ import is.hail.backend._
import is.hail.backend.caching.BlockMatrixCache
import is.hail.backend.spark.SparkBackend
import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex}
import is.hail.expr.ir.{BaseIR, BlockMatrixIR, CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue}
import is.hail.expr.ir.{
BaseIR, BlockMatrixIR, CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser,
Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue,
}
import is.hail.expr.ir.IRParser.parseType
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.functions.IRFunctionRegistry
@@ -21,14 +24,18 @@ import is.hail.utils._
import is.hail.utils.ExecutionTimer.Timings
import is.hail.variant.ReferenceGenome

import scala.annotation.nowarn
import scala.collection.mutable
import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter}

import java.io.Closeable
import java.net.InetSocketAddress
import java.util
import java.util.concurrent._

import com.google.api.client.http.HttpStatusCodes
import com.sun.net.httpserver.{HttpExchange, HttpServer}
import javax.annotation.Nullable
import org.apache.hadoop
import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.DataFrame
@@ -37,9 +44,6 @@ import org.json4s._
import org.json4s.jackson.{JsonMethods, Serialization}
import sourcecode.Enclosing

import javax.annotation.Nullable
import scala.annotation.nowarn

final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling {

private[this] val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
@@ -74,6 +78,9 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling
manager.close()
}

def pyFs: FS =
tmpFileManager.getFs

def pyGetFlag(name: String): String =
flags.get(name)

@@ -83,13 +90,14 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling
def pyAvailableFlags: java.util.ArrayList[String] =
flags.available

def pySetTmpdir(tmp: String): Unit =
def pySetRemoteTmp(tmp: String): Unit =
tmpdir = tmp

def pySetLocalTmp(tmp: String): Unit =
localTmpdir = tmp

def pySetRequesterPays(@Nullable project: String, @Nullable buckets: util.List[String]): Unit = {
def pySetGcsRequesterPaysConfig(@Nullable project: String, @Nullable buckets: util.List[String])
: Unit = {
val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags)

val rpConfig: Option[RequesterPaysConfig] =
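To make the wiring concrete, here is a minimal sketch (assuming a py4j handle to the new Py4JBackendApi; the function below is hypothetical) of the calls the Python Py4JBackend now makes through these entry points:

def wire_up_backend(jbackend_api, local_tmpdir, remote_tmpdir):
    # Point the JVM at the Python-chosen temporary directories.
    jbackend_api.pySetLocalTmp(local_tmpdir)
    jbackend_api.pySetRemoteTmp(remote_tmpdir)
    # Fetch the JVM-side FS (wrapped by HadoopFS in Python) and start the
    # HTTP server whose port the Python side records for later requests.
    jfs = jbackend_api.pyFs()
    jhttp_server = jbackend_api.pyHttpServer()
    return jfs, jhttp_server.port()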
@@ -11,7 +11,9 @@ import is.hail.io.fs.{CloudStorageFSConfig, FS, RouterFS}
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
import is.hail.services._
import is.hail.types.virtual.Kinds
import is.hail.utils.{toRichIterable, using, ErrorHandling, ExecutionTimer, HailWorkerException, Logging}
import is.hail.utils.{
toRichIterable, using, ErrorHandling, ExecutionTimer, HailWorkerException, Logging,
}
import is.hail.utils.ExecutionTimer.Timings
import is.hail.variant.ReferenceGenome

1 change: 0 additions & 1 deletion hail/src/main/scala/is/hail/utils/package.scala
@@ -92,7 +92,6 @@ package utils {
}
}


class Lazy[A] private[utils] (f: => A) {
private[this] var option: Option[A] = None

4 changes: 2 additions & 2 deletions hail/src/test/scala/is/hail/HailSuite.scala
@@ -70,8 +70,8 @@ class HailSuite extends TestNGSuite with TestUtils {
var pool: RegionPool = _
private[this] var ctx_ : ExecuteContext = _

def backend: Backend = ctx.backend
def sc: SparkContext = backend.asSpark.sc
def backend: Backend = hc.backend
def sc: SparkContext = hc.backend.asSpark.sc
def timer: ExecutionTimer = ctx.timer
def theHailClassLoader: HailClassLoader = ctx.theHailClassLoader
override def ctx: ExecuteContext = ctx_