Skip to content

Commit 525c3e7

Browse files
committed
reduce code duplication
1 parent 1cd1263 commit 525c3e7

File tree

14 files changed

+91 -91 lines changed

14 files changed

+91 -91 lines changed

hail/src/main/scala/is/hail/backend/local/LocalBackend.scala

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions
88
import is.hail.expr.Validate
99
import is.hail.expr.ir._
1010
import is.hail.expr.ir.analyses.SemanticHash
11+
import is.hail.expr.ir.compile.Compile
1112
import is.hail.expr.ir.lowering._
1213
import is.hail.io.fs._
1314
import is.hail.types._

hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,10 @@ import is.hail.asm4s._
66
import is.hail.backend._
77
import is.hail.expr.Validate
88
import is.hail.expr.ir.{
9-
Compile, IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader,
10-
TypeCheck,
9+
IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck,
1110
}
1211
import is.hail.expr.ir.analyses.SemanticHash
12+
import is.hail.expr.ir.compile.Compile
1313
import is.hail.expr.ir.functions.IRFunctionRegistry
1414
import is.hail.expr.ir.lowering._
1515
import is.hail.io.fs._

hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions
99
import is.hail.expr.Validate
1010
import is.hail.expr.ir._
1111
import is.hail.expr.ir.analyses.SemanticHash
12+
import is.hail.expr.ir.compile.Compile
1213
import is.hail.expr.ir.lowering._
1314
import is.hail.io.{BufferSpec, TypedCodecSpec}
1415
import is.hail.io.fs._

hail/src/main/scala/is/hail/expr/ir/Compile.scala

Lines changed: 61 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -13,11 +13,12 @@ import is.hail.types.physical.stypes.{
1313
PTypeReferenceSingleCodeType, SingleCodeType, StreamSingleCodeType,
1414
}
1515
import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SStream}
16-
import is.hail.types.virtual.Type
1716
import is.hail.utils._
1817

1918
import java.io.PrintWriter
2019

20+
import sourcecode.Enclosing
21+
2122
case class CodeCacheKey(
2223
aggSigs: IndexedSeq[AggStateSig],
2324
args: Seq[(Name, EmitParamType)],
@@ -32,8 +33,9 @@ case class CompiledFunction[T](
3233
(typ, f)
3334
}
3435

35-
object Compile {
36-
def apply[F: TypeInfo](
36+
object compile {
37+
38+
def Compile[F: TypeInfo](
3739
ctx: ExecuteContext,
3840
params: IndexedSeq[(Name, EmitParamType)],
3941
expectedCodeParamTypes: IndexedSeq[TypeInfo[_]],
@@ -42,27 +44,69 @@ object Compile {
4244
optimize: Boolean = true,
4345
print: Option[PrintWriter] = None,
4446
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F) =
47+
Impl[F, AnyVal](
48+
ctx,
49+
params,
50+
None,
51+
expectedCodeParamTypes,
52+
expectedCodeReturnType,
53+
body,
54+
optimize,
55+
print,
56+
)
57+
58+
def CompileWithAggregators[F: TypeInfo](
59+
ctx: ExecuteContext,
60+
aggSigs: Array[AggStateSig],
61+
params: IndexedSeq[(Name, EmitParamType)],
62+
expectedCodeParamTypes: IndexedSeq[TypeInfo[_]],
63+
expectedCodeReturnType: TypeInfo[_],
64+
body: IR,
65+
optimize: Boolean = true,
66+
print: Option[PrintWriter] = None,
67+
): (
68+
Option[SingleCodeType],
69+
(HailClassLoader, FS, HailTaskContext, Region) => F with FunctionWithAggRegion,
70+
) =
71+
Impl[F, FunctionWithAggRegion](
72+
ctx,
73+
params,
74+
Some(aggSigs),
75+
expectedCodeParamTypes,
76+
expectedCodeReturnType,
77+
body,
78+
optimize,
79+
print,
80+
)
81+
82+
private[this] def Impl[F: TypeInfo, Mixin](
83+
ctx: ExecuteContext,
84+
params: IndexedSeq[(Name, EmitParamType)],
85+
aggSigs: Option[Array[AggStateSig]],
86+
expectedCodeParamTypes: IndexedSeq[TypeInfo[_]],
87+
expectedCodeReturnType: TypeInfo[_],
88+
body: IR,
89+
optimize: Boolean,
90+
print: Option[PrintWriter],
91+
)(implicit
92+
E: Enclosing,
93+
N: sourcecode.Name,
94+
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F with Mixin) =
4595
ctx.time {
4696
val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true)
4797
ctx.CodeCache.getOrElseUpdate(
48-
CodeCacheKey(FastSeq(), params.map { case (n, pt) => (n, pt) }, normalizedBody), {
49-
var ir = body
50-
ir = Subst(
51-
ir,
52-
BindingEnv(params
53-
.zipWithIndex
54-
.foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }),
98+
CodeCacheKey(aggSigs.getOrElse(Array.empty).toFastSeq, params, normalizedBody), {
99+
var ir = Subst(
100+
body,
101+
BindingEnv(Env.fromSeq(params.zipWithIndex.map { case ((n, t), i) => n -> In(i, t) })),
55102
)
56103
ir = LoweringPipeline.compileLowerer(optimize)(ctx, ir).asInstanceOf[IR].noSharing(ctx)
57-
58104
TypeCheck(ctx, ir)
59105

60106
val fb = EmitFunctionBuilder[F](
61107
ctx,
62-
"Compiled",
63-
CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) =>
64-
pt
65-
},
108+
N.value,
109+
CodeParamType(typeInfo[Region]) +: params.map(_._2),
66110
CodeParamType(SingleCodeType.typeInfoFromType(ir.typ)),
67111
Some("Emit.scala"),
68112
)
@@ -83,65 +127,10 @@ object Compile {
83127
)
84128

85129
val emitContext = EmitContext.analyze(ctx, ir)
86-
val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length)
130+
val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length, aggSigs)
87131
CompiledFunction(rt, fb.resultWithIndex(print))
88132
},
89-
).asInstanceOf[CompiledFunction[F]].tuple
90-
}
91-
}
92-
93-
object CompileWithAggregators {
94-
def apply[F: TypeInfo](
95-
ctx: ExecuteContext,
96-
aggSigs: Array[AggStateSig],
97-
params: IndexedSeq[(Name, EmitParamType)],
98-
expectedCodeParamTypes: IndexedSeq[TypeInfo[_]],
99-
expectedCodeReturnType: TypeInfo[_],
100-
body: IR,
101-
optimize: Boolean = true,
102-
): (
103-
Option[SingleCodeType],
104-
(HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion),
105-
) =
106-
ctx.time {
107-
val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true)
108-
ctx.CodeCache.getOrElseUpdate(
109-
CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody), {
110-
var ir = body
111-
ir = Subst(
112-
ir,
113-
BindingEnv(params
114-
.zipWithIndex
115-
.foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }),
116-
)
117-
ir =
118-
LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx)
119-
120-
TypeCheck(
121-
ctx,
122-
ir,
123-
BindingEnv(Env.fromSeq[Type](params.map { case (name, t) => name -> t.virtualType })),
124-
)
125-
126-
val fb = EmitFunctionBuilder[F with FunctionWithAggRegion](
127-
ctx,
128-
"CompiledWithAggs",
129-
CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => pt },
130-
SingleCodeType.typeInfoFromType(ir.typ),
131-
Some("Emit.scala"),
132-
)
133-
134-
/* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${
135-
* x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c)
136-
* } }
137-
*
138-
* visit(ir) } */
139-
140-
val emitContext = EmitContext.analyze(ctx, ir)
141-
val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length, Some(aggSigs))
142-
CompiledFunction(rt, fb.resultWithIndex())
143-
},
144-
).asInstanceOf[CompiledFunction[F with FunctionWithAggRegion]].tuple
133+
).asInstanceOf[CompiledFunction[F with Mixin]].tuple
145134
}
146135
}
147136

hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@ package is.hail.expr.ir
33
import is.hail.annotations.{Region, SafeRow}
44
import is.hail.asm4s._
55
import is.hail.backend.ExecuteContext
6+
import is.hail.expr.ir.compile.Compile
67
import is.hail.expr.ir.lowering.LoweringPipeline
78
import is.hail.types.physical.PTuple
89
import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType

hail/src/main/scala/is/hail/expr/ir/Emit.scala

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@ import is.hail.expr.ir.agg.{AggStateSig, ArrayAggStateSig, GroupedStateSig}
77
import is.hail.expr.ir.analyses.{
88
ComputeMethodSplits, ControlFlowPreventsSplit, ParentPointers, SemanticHash,
99
}
10+
import is.hail.expr.ir.compile.Compile
1011
import is.hail.expr.ir.lowering.TableStageDependency
1112
import is.hail.expr.ir.ndarrays.EmitNDArray
1213
import is.hail.expr.ir.streams.{EmitStream, StreamProducer, StreamUtils}

hail/src/main/scala/is/hail/expr/ir/Interpret.scala

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@ import is.hail.annotations._
44
import is.hail.asm4s._
55
import is.hail.backend.{ExecuteContext, HailTaskContext}
66
import is.hail.backend.spark.SparkTaskContext
7+
import is.hail.expr.ir.compile.{Compile, CompileWithAggregators}
78
import is.hail.expr.ir.lowering.LoweringPipeline
89
import is.hail.io.BufferSpec
910
import is.hail.linalg.BlockMatrix

hail/src/main/scala/is/hail/expr/ir/TableIR.scala

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ import is.hail.annotations._
55
import is.hail.asm4s._
66
import is.hail.backend.{ExecuteContext, HailStateManager, HailTaskContext, TaskFinalizer}
77
import is.hail.backend.spark.{SparkBackend, SparkTaskContext}
8-
import is.hail.expr.ir
8+
import is.hail.expr.ir.compile.{Compile, CompileWithAggregators}
99
import is.hail.expr.ir.functions.{
1010
BlockMatrixToTableFunction, IntervalFunctions, MatrixToTableFunction, TableToTableFunction,
1111
}
@@ -1931,7 +1931,7 @@ case class TableNativeZippedReader(
19311931
val leftRef = Ref(freshName(), pLeft.virtualType)
19321932
val rightRef = Ref(freshName(), pRight.virtualType)
19331933
val (Some(PTypeReferenceSingleCodeType(t: PStruct)), mk) =
1934-
ir.Compile[AsmFunction3RegionLongLongLong](
1934+
Compile[AsmFunction3RegionLongLongLong](
19351935
ctx,
19361936
FastSeq(
19371937
leftRef.name -> SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(pLeft)),
@@ -2420,7 +2420,7 @@ case class TableFilter(child: TableIR, pred: IR) extends TableIR {
24202420
else if (pred == False())
24212421
return TableValueIntermediate(tv.copy(rvd = RVD.empty(ctx, typ.canonicalRVDType)))
24222422

2423-
val (Some(BooleanSingleCodeType), f) = ir.Compile[AsmFunction3RegionLongLongBoolean](
2423+
val (Some(BooleanSingleCodeType), f) = Compile[AsmFunction3RegionLongLongBoolean](
24242424
ctx,
24252425
FastSeq(
24262426
(
@@ -3035,7 +3035,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR {
30353035

30363036
if (extracted.aggs.isEmpty) {
30373037
val (Some(PTypeReferenceSingleCodeType(rTyp)), f) =
3038-
ir.Compile[AsmFunction3RegionLongLongLong](
3038+
Compile[AsmFunction3RegionLongLongLong](
30393039
ctx,
30403040
FastSeq(
30413041
(
@@ -3101,7 +3101,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR {
31013101
// 3. load in partition aggregations, comb op as necessary, serialize.
31023102
// 4. load in partStarts, calculate newRow based on those results.
31033103

3104-
val (_, initF) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit](
3104+
val (_, initF) = CompileWithAggregators[AsmFunction2RegionLongUnit](
31053105
ctx,
31063106
extracted.states,
31073107
FastSeq((
@@ -3115,7 +3115,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR {
31153115

31163116
val serializeF = extracted.serialize(ctx, spec)
31173117

3118-
val (_, eltSeqF) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit](
3118+
val (_, eltSeqF) = CompileWithAggregators[AsmFunction3RegionLongLongUnit](
31193119
ctx,
31203120
extracted.states,
31213121
FastSeq(
@@ -3138,7 +3138,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR {
31383138
val combOpFNeedsPool = extracted.combOpFSerializedFromRegionPool(ctx, spec)
31393139

31403140
val (Some(PTypeReferenceSingleCodeType(rTyp)), f) =
3141-
ir.CompileWithAggregators[AsmFunction3RegionLongLongLong](
3141+
CompileWithAggregators[AsmFunction3RegionLongLongLong](
31423142
ctx,
31433143
extracted.states,
31443144
FastSeq(
@@ -3697,7 +3697,7 @@ case class TableKeyByAndAggregate(
36973697

36983698
val localKeyType = keyType
36993699
val (Some(PTypeReferenceSingleCodeType(localKeyPType: PStruct)), makeKeyF) =
3700-
ir.Compile[AsmFunction3RegionLongLongLong](
3700+
Compile[AsmFunction3RegionLongLongLong](
37013701
ctx,
37023702
FastSeq(
37033703
(
@@ -3723,7 +3723,7 @@ case class TableKeyByAndAggregate(
37233723

37243724
val extracted = agg.Extract(expr, Requiredness(this, ctx))
37253725

3726-
val (_, makeInit) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit](
3726+
val (_, makeInit) = CompileWithAggregators[AsmFunction2RegionLongUnit](
37273727
ctx,
37283728
extracted.states,
37293729
FastSeq((
@@ -3735,7 +3735,7 @@ case class TableKeyByAndAggregate(
37353735
extracted.init,
37363736
)
37373737

3738-
val (_, makeSeq) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit](
3738+
val (_, makeSeq) = CompileWithAggregators[AsmFunction3RegionLongLongUnit](
37393739
ctx,
37403740
extracted.states,
37413741
FastSeq(
@@ -3754,7 +3754,7 @@ case class TableKeyByAndAggregate(
37543754
)
37553755

37563756
val (Some(PTypeReferenceSingleCodeType(rTyp: PStruct)), makeAnnotate) =
3757-
ir.CompileWithAggregators[AsmFunction2RegionLongLong](
3757+
CompileWithAggregators[AsmFunction2RegionLongLong](
37583758
ctx,
37593759
extracted.states,
37603760
FastSeq((
@@ -3897,7 +3897,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR {
38973897

38983898
val extracted = agg.Extract(expr, Requiredness(this, ctx))
38993899

3900-
val (_, makeInit) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit](
3900+
val (_, makeInit) = CompileWithAggregators[AsmFunction2RegionLongUnit](
39013901
ctx,
39023902
extracted.states,
39033903
FastSeq((
@@ -3909,7 +3909,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR {
39093909
extracted.init,
39103910
)
39113911

3912-
val (_, makeSeq) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit](
3912+
val (_, makeSeq) = CompileWithAggregators[AsmFunction3RegionLongLongUnit](
39133913
ctx,
39143914
extracted.states,
39153915
FastSeq(
@@ -3933,7 +3933,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR {
39333933
val key = Ref(freshName(), keyType.virtualType)
39343934
val value = Ref(freshName(), valueIR.typ)
39353935
val (Some(PTypeReferenceSingleCodeType(rowType: PStruct)), makeRow) =
3936-
ir.CompileWithAggregators[AsmFunction3RegionLongLongLong](
3936+
CompileWithAggregators[AsmFunction3RegionLongLongLong](
39373937
ctx,
39383938
extracted.states,
39393939
FastSeq(

hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ import is.hail.backend.{ExecuteContext, HailTaskContext}
66
import is.hail.backend.spark.SparkTaskContext
77
import is.hail.expr.ir
88
import is.hail.expr.ir._
9+
import is.hail.expr.ir.compile.CompileWithAggregators
910
import is.hail.io.BufferSpec
1011
import is.hail.types.{TypeWithRequiredness, VirtualTypeWithReq}
1112
import is.hail.types.physical.stypes.EmitType
@@ -247,7 +248,7 @@ class Aggs(
247248

248249
def deserialize(ctx: ExecuteContext, spec: BufferSpec)
249250
: ((HailClassLoader, HailTaskContext, Region, Array[Byte]) => Long) = {
250-
val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit](
251+
val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit](
251252
ctx,
252253
states,
253254
FastSeq(),
@@ -268,7 +269,7 @@ class Aggs(
268269

269270
def serialize(ctx: ExecuteContext, spec: BufferSpec)
270271
: (HailClassLoader, HailTaskContext, Region, Long) => Array[Byte] = {
271-
val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit](
272+
val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit](
272273
ctx,
273274
states,
274275
FastSeq(),
@@ -305,7 +306,7 @@ class Aggs(
305306
: (() => (RegionPool, HailClassLoader, HailTaskContext)) => (
306307
(Array[Byte], Array[Byte]) => Array[Byte],
307308
) = {
308-
val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit](
309+
val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit](
309310
ctx,
310311
states ++ states,
311312
FastSeq(),

hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@ import is.hail.annotations.{Annotation, ExtendedOrdering, Region, SafeRow}
44
import is.hail.asm4s.{classInfo, AsmFunction1RegionLong, LongInfo}
55
import is.hail.backend.{ExecuteContext, HailStateManager}
66
import is.hail.expr.ir._
7+
import is.hail.expr.ir.compile.Compile
78
import is.hail.expr.ir.functions.{ArrayFunctions, IRRandomness, UtilFunctions}
89
import is.hail.io.{BufferSpec, TypedCodecSpec}
910
import is.hail.rvd.RVDPartitioner

0 commit comments

Comments
 (0)