diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index 09983e0f698..10498f6debf 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -8,6 +8,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate import is.hail.expr.ir._ import is.hail.expr.ir.analyses.SemanticHash +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering._ import is.hail.io.fs._ import is.hail.types._ diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 6a34503fdd0..cb4fd37a901 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -6,10 +6,10 @@ import is.hail.asm4s._ import is.hail.backend._ import is.hail.expr.Validate import is.hail.expr.ir.{ - Compile, IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, - TypeCheck, + IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck, } import is.hail.expr.ir.analyses.SemanticHash +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering._ import is.hail.io.fs._ diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index 3380e76a4cd..baad79e9b1e 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -9,6 +9,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate import is.hail.expr.ir._ import is.hail.expr.ir.analyses.SemanticHash +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering._ import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs._ diff --git a/hail/src/main/scala/is/hail/expr/ir/Compile.scala b/hail/src/main/scala/is/hail/expr/ir/Compile.scala index b54537947b1..dbe02dd5e30 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Compile.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Compile.scala @@ -13,11 +13,12 @@ import is.hail.types.physical.stypes.{ PTypeReferenceSingleCodeType, SingleCodeType, StreamSingleCodeType, } import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SStream} -import is.hail.types.virtual.Type import is.hail.utils._ import java.io.PrintWriter +import sourcecode.Enclosing + case class CodeCacheKey( aggSigs: IndexedSeq[AggStateSig], args: Seq[(Name, EmitParamType)], @@ -32,8 +33,9 @@ case class CompiledFunction[T]( (typ, f) } -object Compile { - def apply[F: TypeInfo]( +object compile { + + def Compile[F: TypeInfo]( ctx: ExecuteContext, params: IndexedSeq[(Name, EmitParamType)], expectedCodeParamTypes: IndexedSeq[TypeInfo[_]], @@ -42,27 +44,69 @@ object Compile { optimize: Boolean = true, print: Option[PrintWriter] = None, ): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F) = + Impl[F, AnyVal]( + ctx, + params, + None, + expectedCodeParamTypes, + expectedCodeReturnType, + body, + optimize, + print, + ) + + def CompileWithAggregators[F: TypeInfo]( + ctx: ExecuteContext, + aggSigs: Array[AggStateSig], + params: IndexedSeq[(Name, EmitParamType)], + expectedCodeParamTypes: IndexedSeq[TypeInfo[_]], + expectedCodeReturnType: TypeInfo[_], + body: IR, + optimize: Boolean = true, + print: Option[PrintWriter] = None, + ): ( + Option[SingleCodeType], + (HailClassLoader, FS, HailTaskContext, Region) => F with FunctionWithAggRegion, + ) = + Impl[F, FunctionWithAggRegion]( + ctx, + params, + Some(aggSigs), + expectedCodeParamTypes, + expectedCodeReturnType, + body, + optimize, + print, + ) + + private[this] def Impl[F: TypeInfo, Mixin]( + ctx: ExecuteContext, + params: IndexedSeq[(Name, EmitParamType)], + aggSigs: Option[Array[AggStateSig]], + expectedCodeParamTypes: IndexedSeq[TypeInfo[_]], + expectedCodeReturnType: TypeInfo[_], + body: IR, + optimize: Boolean, + print: Option[PrintWriter], + )(implicit + E: Enclosing, + N: sourcecode.Name, + ): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F with Mixin) = ctx.time { val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true) ctx.CodeCache.getOrElseUpdate( - CodeCacheKey(FastSeq(), params.map { case (n, pt) => (n, pt) }, normalizedBody), { - var ir = body - ir = Subst( - ir, - BindingEnv(params - .zipWithIndex - .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }), + CodeCacheKey(aggSigs.getOrElse(Array.empty).toFastSeq, params, normalizedBody), { + var ir = Subst( + body, + BindingEnv(Env.fromSeq(params.zipWithIndex.map { case ((n, t), i) => n -> In(i, t) })), ) ir = LoweringPipeline.compileLowerer(optimize)(ctx, ir).asInstanceOf[IR].noSharing(ctx) - TypeCheck(ctx, ir) val fb = EmitFunctionBuilder[F]( ctx, - "Compiled", - CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => - pt - }, + N.value, + CodeParamType(typeInfo[Region]) +: params.map(_._2), CodeParamType(SingleCodeType.typeInfoFromType(ir.typ)), Some("Emit.scala"), ) @@ -83,65 +127,10 @@ object Compile { ) val emitContext = EmitContext.analyze(ctx, ir) - val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length) + val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length, aggSigs) CompiledFunction(rt, fb.resultWithIndex(print)) }, - ).asInstanceOf[CompiledFunction[F]].tuple - } -} - -object CompileWithAggregators { - def apply[F: TypeInfo]( - ctx: ExecuteContext, - aggSigs: Array[AggStateSig], - params: IndexedSeq[(Name, EmitParamType)], - expectedCodeParamTypes: IndexedSeq[TypeInfo[_]], - expectedCodeReturnType: TypeInfo[_], - body: IR, - optimize: Boolean = true, - ): ( - Option[SingleCodeType], - (HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion), - ) = - ctx.time { - val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true) - ctx.CodeCache.getOrElseUpdate( - CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody), { - var ir = body - ir = Subst( - ir, - BindingEnv(params - .zipWithIndex - .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }), - ) - ir = - LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx) - - TypeCheck( - ctx, - ir, - BindingEnv(Env.fromSeq[Type](params.map { case (name, t) => name -> t.virtualType })), - ) - - val fb = EmitFunctionBuilder[F with FunctionWithAggRegion]( - ctx, - "CompiledWithAggs", - CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => pt }, - SingleCodeType.typeInfoFromType(ir.typ), - Some("Emit.scala"), - ) - - /* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${ - * x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c) - * } } - * - * visit(ir) } */ - - val emitContext = EmitContext.analyze(ctx, ir) - val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length, Some(aggSigs)) - CompiledFunction(rt, fb.resultWithIndex()) - }, - ).asInstanceOf[CompiledFunction[F with FunctionWithAggRegion]].tuple + ).asInstanceOf[CompiledFunction[F with Mixin]].tuple } } diff --git a/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala b/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala index 51257f1b1a9..b395a1ff84f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala +++ b/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala @@ -3,6 +3,7 @@ package is.hail.expr.ir import is.hail.annotations.{Region, SafeRow} import is.hail.asm4s._ import is.hail.backend.ExecuteContext +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering.LoweringPipeline import is.hail.types.physical.PTuple import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType diff --git a/hail/src/main/scala/is/hail/expr/ir/Emit.scala b/hail/src/main/scala/is/hail/expr/ir/Emit.scala index f52325c6f23..6674d50ab3c 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Emit.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Emit.scala @@ -7,6 +7,7 @@ import is.hail.expr.ir.agg.{AggStateSig, ArrayAggStateSig, GroupedStateSig} import is.hail.expr.ir.analyses.{ ComputeMethodSplits, ControlFlowPreventsSplit, ParentPointers, SemanticHash, } +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering.TableStageDependency import is.hail.expr.ir.ndarrays.EmitNDArray import is.hail.expr.ir.streams.{EmitStream, StreamProducer, StreamUtils} diff --git a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala index e7628c600d8..1eea7e096e2 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala @@ -4,6 +4,7 @@ import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend.{ExecuteContext, HailTaskContext} import is.hail.backend.spark.SparkTaskContext +import is.hail.expr.ir.compile.{Compile, CompileWithAggregators} import is.hail.expr.ir.lowering.LoweringPipeline import is.hail.io.BufferSpec import is.hail.linalg.BlockMatrix diff --git a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala index cbd2b2bc254..4c528d79a98 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala @@ -5,7 +5,7 @@ import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend.{ExecuteContext, HailStateManager, HailTaskContext, TaskFinalizer} import is.hail.backend.spark.{SparkBackend, SparkTaskContext} -import is.hail.expr.ir +import is.hail.expr.ir.compile.{Compile, CompileWithAggregators} import is.hail.expr.ir.functions.{ BlockMatrixToTableFunction, IntervalFunctions, MatrixToTableFunction, TableToTableFunction, } @@ -1931,7 +1931,7 @@ case class TableNativeZippedReader( val leftRef = Ref(freshName(), pLeft.virtualType) val rightRef = Ref(freshName(), pRight.virtualType) val (Some(PTypeReferenceSingleCodeType(t: PStruct)), mk) = - ir.Compile[AsmFunction3RegionLongLongLong]( + Compile[AsmFunction3RegionLongLongLong]( ctx, FastSeq( leftRef.name -> SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(pLeft)), @@ -2420,7 +2420,7 @@ case class TableFilter(child: TableIR, pred: IR) extends TableIR { else if (pred == False()) return TableValueIntermediate(tv.copy(rvd = RVD.empty(ctx, typ.canonicalRVDType))) - val (Some(BooleanSingleCodeType), f) = ir.Compile[AsmFunction3RegionLongLongBoolean]( + val (Some(BooleanSingleCodeType), f) = Compile[AsmFunction3RegionLongLongBoolean]( ctx, FastSeq( ( @@ -3035,7 +3035,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { if (extracted.aggs.isEmpty) { val (Some(PTypeReferenceSingleCodeType(rTyp)), f) = - ir.Compile[AsmFunction3RegionLongLongLong]( + Compile[AsmFunction3RegionLongLongLong]( ctx, FastSeq( ( @@ -3101,7 +3101,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { // 3. load in partition aggregations, comb op as necessary, serialize. // 4. load in partStarts, calculate newRow based on those results. - val (_, initF) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit]( + val (_, initF) = CompileWithAggregators[AsmFunction2RegionLongUnit]( ctx, extracted.states, FastSeq(( @@ -3115,7 +3115,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { val serializeF = extracted.serialize(ctx, spec) - val (_, eltSeqF) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit]( + val (_, eltSeqF) = CompileWithAggregators[AsmFunction3RegionLongLongUnit]( ctx, extracted.states, FastSeq( @@ -3138,7 +3138,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { val combOpFNeedsPool = extracted.combOpFSerializedFromRegionPool(ctx, spec) val (Some(PTypeReferenceSingleCodeType(rTyp)), f) = - ir.CompileWithAggregators[AsmFunction3RegionLongLongLong]( + CompileWithAggregators[AsmFunction3RegionLongLongLong]( ctx, extracted.states, FastSeq( @@ -3697,7 +3697,7 @@ case class TableKeyByAndAggregate( val localKeyType = keyType val (Some(PTypeReferenceSingleCodeType(localKeyPType: PStruct)), makeKeyF) = - ir.Compile[AsmFunction3RegionLongLongLong]( + Compile[AsmFunction3RegionLongLongLong]( ctx, FastSeq( ( @@ -3723,7 +3723,7 @@ case class TableKeyByAndAggregate( val extracted = agg.Extract(expr, Requiredness(this, ctx)) - val (_, makeInit) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit]( + val (_, makeInit) = CompileWithAggregators[AsmFunction2RegionLongUnit]( ctx, extracted.states, FastSeq(( @@ -3735,7 +3735,7 @@ case class TableKeyByAndAggregate( extracted.init, ) - val (_, makeSeq) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit]( + val (_, makeSeq) = CompileWithAggregators[AsmFunction3RegionLongLongUnit]( ctx, extracted.states, FastSeq( @@ -3754,7 +3754,7 @@ case class TableKeyByAndAggregate( ) val (Some(PTypeReferenceSingleCodeType(rTyp: PStruct)), makeAnnotate) = - ir.CompileWithAggregators[AsmFunction2RegionLongLong]( + CompileWithAggregators[AsmFunction2RegionLongLong]( ctx, extracted.states, FastSeq(( @@ -3897,7 +3897,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { val extracted = agg.Extract(expr, Requiredness(this, ctx)) - val (_, makeInit) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit]( + val (_, makeInit) = CompileWithAggregators[AsmFunction2RegionLongUnit]( ctx, extracted.states, FastSeq(( @@ -3909,7 +3909,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { extracted.init, ) - val (_, makeSeq) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit]( + val (_, makeSeq) = CompileWithAggregators[AsmFunction3RegionLongLongUnit]( ctx, extracted.states, FastSeq( @@ -3933,7 +3933,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { val key = Ref(freshName(), keyType.virtualType) val value = Ref(freshName(), valueIR.typ) val (Some(PTypeReferenceSingleCodeType(rowType: PStruct)), makeRow) = - ir.CompileWithAggregators[AsmFunction3RegionLongLongLong]( + CompileWithAggregators[AsmFunction3RegionLongLongLong]( ctx, extracted.states, FastSeq( diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala b/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala index 4b4d5c3ac4e..84099594230 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala @@ -6,6 +6,7 @@ import is.hail.backend.{ExecuteContext, HailTaskContext} import is.hail.backend.spark.SparkTaskContext import is.hail.expr.ir import is.hail.expr.ir._ +import is.hail.expr.ir.compile.CompileWithAggregators import is.hail.io.BufferSpec import is.hail.types.{TypeWithRequiredness, VirtualTypeWithReq} import is.hail.types.physical.stypes.EmitType @@ -247,7 +248,7 @@ class Aggs( def deserialize(ctx: ExecuteContext, spec: BufferSpec) : ((HailClassLoader, HailTaskContext, Region, Array[Byte]) => Long) = { - val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit]( + val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit]( ctx, states, FastSeq(), @@ -268,7 +269,7 @@ class Aggs( def serialize(ctx: ExecuteContext, spec: BufferSpec) : (HailClassLoader, HailTaskContext, Region, Long) => Array[Byte] = { - val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit]( + val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit]( ctx, states, FastSeq(), @@ -305,7 +306,7 @@ class Aggs( : (() => (RegionPool, HailClassLoader, HailTaskContext)) => ( (Array[Byte], Array[Byte]) => Array[Byte], ) = { - val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit]( + val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit]( ctx, states ++ states, FastSeq(), diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala index 859dfcaa5ba..aab36e0f847 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala @@ -4,6 +4,7 @@ import is.hail.annotations.{Annotation, ExtendedOrdering, Region, SafeRow} import is.hail.asm4s.{classInfo, AsmFunction1RegionLong, LongInfo} import is.hail.backend.{ExecuteContext, HailStateManager} import is.hail.expr.ir._ +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.functions.{ArrayFunctions, IRRandomness, UtilFunctions} import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.rvd.RVDPartitioner diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala index e3423bf9f75..e0af240efc5 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala @@ -5,6 +5,7 @@ import is.hail.asm4s._ import is.hail.backend.{BroadcastValue, ExecuteContext} import is.hail.backend.spark.{AnonymousDependency, SparkTaskContext} import is.hail.expr.ir._ +import is.hail.expr.ir.compile.Compile import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs.FS import is.hail.rvd.{RVD, RVDType} diff --git a/hail/src/test/scala/is/hail/TestUtils.scala b/hail/src/test/scala/is/hail/TestUtils.scala index 926cf7753d3..ff7a4ed4868 100644 --- a/hail/src/test/scala/is/hail/TestUtils.scala +++ b/hail/src/test/scala/is/hail/TestUtils.scala @@ -4,6 +4,7 @@ import is.hail.annotations.{Region, RegionValueBuilder, SafeRow} import is.hail.asm4s._ import is.hail.backend.ExecuteContext import is.hail.expr.ir._ +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering.LowererUnsupportedOperation import is.hail.io.vcf.MatrixVCFReader import is.hail.types.physical.{PBaseStruct, PCanonicalArray, PType} diff --git a/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala b/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala index 9cd09daf7e6..910fef630e1 100644 --- a/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala @@ -4,6 +4,7 @@ import is.hail.{ExecStrategy, HailSuite} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.expr.ir.agg._ +import is.hail.expr.ir.compile.CompileWithAggregators import is.hail.io.BufferSpec import is.hail.types.VirtualTypeWithReq import is.hail.types.physical._ diff --git a/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala b/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala index 0c1b0393fe2..4ab89a11177 100644 --- a/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala @@ -6,6 +6,7 @@ import is.hail.annotations.{Region, SafeRow, ScalaToRegionValue} import is.hail.asm4s._ import is.hail.backend.ExecuteContext import is.hail.expr.ir.agg.{CollectStateSig, PhysicalAggSig, TypedStateSig} +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering.LoweringPipeline import is.hail.expr.ir.streams.{EmitStream, StreamUtils} import is.hail.types.VirtualTypeWithReq