Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[query] Remove lookupOrCompileCachedFunction from Backend interface #14691

Open
wants to merge 2 commits into
base: ehigham/ctx-bm-cache
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 1 addition & 25 deletions hail/src/main/scala/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ import is.hail.asm4s._
import is.hail.backend.Backend.jsonToBytes
import is.hail.backend.spark.SparkBackend
import is.hail.expr.ir.{
BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses,
SortField, TableIR, TableReader,
BaseIR, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader,
}
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.io.{BufferSpec, TypedCodecSpec}
Expand Down Expand Up @@ -105,9 +104,6 @@ abstract class Backend extends Closeable {

def shouldCacheQueryInfo: Boolean = true

def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T]

def lowerDistributedSort(
ctx: ExecuteContext,
stage: TableStage,
Expand Down Expand Up @@ -206,23 +202,3 @@ abstract class Backend extends Closeable {

def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
}

trait BackendWithCodeCache {
private[this] val codeCache: Cache[CodeCacheKey, CompiledFunction[_]] = new Cache(50)

def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T] = {
codeCache.get(k) match {
case Some(v) => v.asInstanceOf[CompiledFunction[T]]
case None =>
val compiledFunction = f
codeCache += ((k, compiledFunction))
compiledFunction
}
}
}

trait BackendWithNoCodeCache {
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T] = f
}
6 changes: 6 additions & 0 deletions hail/src/main/scala/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags}
import is.hail.annotations.{Region, RegionPool}
import is.hail.asm4s.HailClassLoader
import is.hail.backend.local.LocalTaskContext
import is.hail.expr.ir.{CodeCacheKey, CompiledFunction}
import is.hail.expr.ir.lowering.IrMetadata
import is.hail.io.fs.FS
import is.hail.linalg.BlockMatrix
Expand Down Expand Up @@ -73,6 +74,7 @@ object ExecuteContext {
backendContext: BackendContext,
irMetadata: IrMetadata,
blockMatrixCache: mutable.Map[String, BlockMatrix],
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
)(
f: ExecuteContext => T
): T = {
Expand All @@ -92,6 +94,7 @@ object ExecuteContext {
backendContext,
irMetadata,
blockMatrixCache,
codeCache,
))(f(_))
}
}
Expand Down Expand Up @@ -122,6 +125,7 @@ class ExecuteContext(
val backendContext: BackendContext,
val irMetadata: IrMetadata,
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
) extends Closeable {

val rngNonce: Long =
Expand Down Expand Up @@ -191,6 +195,7 @@ class ExecuteContext(
backendContext: BackendContext = this.backendContext,
irMetadata: IrMetadata = this.irMetadata,
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
)(
f: ExecuteContext => A
): A =
Expand All @@ -208,5 +213,6 @@ class ExecuteContext(
backendContext,
irMetadata,
blockMatrixCache,
codeCache,
))(f)
}
7 changes: 5 additions & 2 deletions hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir._
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
import is.hail.types._
Expand Down Expand Up @@ -70,13 +71,14 @@ object LocalBackend {
class LocalBackend(
val tmpdir: String,
override val references: mutable.Map[String, ReferenceGenome],
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {
) extends Backend with Py4JBackendExtensions {

override def backend: Backend = this
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
override def longLifeTempFileManager: TempFileManager = null

private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader())
private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)

// flags can be set after construction from python
def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
Expand All @@ -100,6 +102,7 @@ class LocalBackend(
},
new IrMetadata(),
ImmutableMap.empty,
codeCache,
)(f)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import is.hail.asm4s._
import is.hail.backend._
import is.hail.expr.Validate
import is.hail.expr.ir.{
Compile, IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader,
TypeCheck,
IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck,
}
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
Expand Down Expand Up @@ -51,7 +51,6 @@ class ServiceBackendContext(
) extends BackendContext with Serializable {}

object ServiceBackend {
private val log = Logger.getLogger(getClass.getName())

def apply(
jarLocation: String,
Expand Down Expand Up @@ -129,8 +128,7 @@ class ServiceBackend(
val fs: FS,
val serviceBackendContext: ServiceBackendContext,
val scratchDir: String,
) extends Backend with BackendWithNoCodeCache {
import ServiceBackend.log
) extends Backend with Logging {

private[this] var stageCount = 0
private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000
Expand Down Expand Up @@ -391,6 +389,7 @@ class ServiceBackend(
serviceBackendContext,
new IrMetadata(),
ImmutableMap.empty,
mutable.Map.empty,
)(f)
}

Expand Down
9 changes: 6 additions & 3 deletions hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir._
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.lowering._
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.io.fs._
Expand Down Expand Up @@ -307,7 +308,7 @@ class SparkBackend(
override val references: mutable.Map[String, ReferenceGenome],
gcsRequesterPaysProject: String,
gcsRequesterPaysBuckets: String,
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {
) extends Backend with Py4JBackendExtensions {

assert(gcsRequesterPaysProject != null || gcsRequesterPaysBuckets == null)
lazy val sparkSession: SparkSession = SparkSession.builder().config(sc.getConf).getOrCreate()
Expand Down Expand Up @@ -338,8 +339,8 @@ class SparkBackend(
override val longLifeTempFileManager: TempFileManager =
new OwningTempFileManager(fs)

private[this] val bmCache: BlockMatrixCache =
new BlockMatrixCache()
private[this] val bmCache = new BlockMatrixCache()
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)

def createExecuteContextForTests(
timer: ExecutionTimer,
Expand All @@ -363,6 +364,7 @@ class SparkBackend(
},
new IrMetadata(),
ImmutableMap.empty,
mutable.Map.empty,
)

override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
Expand All @@ -383,6 +385,7 @@ class SparkBackend(
},
new IrMetadata(),
bmCache,
codeCache,
)(f)
}

Expand Down
Loading