@@ -13,11 +13,12 @@ import is.hail.types.physical.stypes.{
1313 PTypeReferenceSingleCodeType , SingleCodeType , StreamSingleCodeType ,
1414}
1515import is .hail .types .physical .stypes .interfaces .{NoBoxLongIterator , SStream }
16- import is .hail .types .virtual .Type
1716import is .hail .utils ._
1817
1918import java .io .PrintWriter
2019
20+ import sourcecode .Enclosing
21+
2122case class CodeCacheKey (
2223 aggSigs : IndexedSeq [AggStateSig ],
2324 args : Seq [(Name , EmitParamType )],
@@ -32,8 +33,9 @@ case class CompiledFunction[T](
3233 (typ, f)
3334}
3435
35- object Compile {
36- def apply [F : TypeInfo ](
36+ object compile {
37+
38+ def Compile [F : TypeInfo ](
3739 ctx : ExecuteContext ,
3840 params : IndexedSeq [(Name , EmitParamType )],
3941 expectedCodeParamTypes : IndexedSeq [TypeInfo [_]],
@@ -42,27 +44,69 @@ object Compile {
4244 optimize : Boolean = true ,
4345 print : Option [PrintWriter ] = None ,
4446 ): (Option [SingleCodeType ], (HailClassLoader , FS , HailTaskContext , Region ) => F ) =
47+ Impl [F , AnyVal ](
48+ ctx,
49+ params,
50+ None ,
51+ expectedCodeParamTypes,
52+ expectedCodeReturnType,
53+ body,
54+ optimize,
55+ print,
56+ )
57+
58+ def CompileWithAggregators [F : TypeInfo ](
59+ ctx : ExecuteContext ,
60+ aggSigs : Array [AggStateSig ],
61+ params : IndexedSeq [(Name , EmitParamType )],
62+ expectedCodeParamTypes : IndexedSeq [TypeInfo [_]],
63+ expectedCodeReturnType : TypeInfo [_],
64+ body : IR ,
65+ optimize : Boolean = true ,
66+ print : Option [PrintWriter ] = None ,
67+ ): (
68+ Option [SingleCodeType ],
69+ (HailClassLoader , FS , HailTaskContext , Region ) => F with FunctionWithAggRegion ,
70+ ) =
71+ Impl [F , FunctionWithAggRegion ](
72+ ctx,
73+ params,
74+ Some (aggSigs),
75+ expectedCodeParamTypes,
76+ expectedCodeReturnType,
77+ body,
78+ optimize,
79+ print,
80+ )
81+
82+ private [this ] def Impl [F : TypeInfo , Mixin ](
83+ ctx : ExecuteContext ,
84+ params : IndexedSeq [(Name , EmitParamType )],
85+ aggSigs : Option [Array [AggStateSig ]],
86+ expectedCodeParamTypes : IndexedSeq [TypeInfo [_]],
87+ expectedCodeReturnType : TypeInfo [_],
88+ body : IR ,
89+ optimize : Boolean ,
90+ print : Option [PrintWriter ],
91+ )(implicit
92+ E : Enclosing ,
93+ N : sourcecode.Name ,
94+ ): (Option [SingleCodeType ], (HailClassLoader , FS , HailTaskContext , Region ) => F with Mixin ) =
4595 ctx.time {
4696 val normalizedBody = NormalizeNames (ctx, body, allowFreeVariables = true )
4797 ctx.CodeCache .getOrElseUpdate(
48- CodeCacheKey (FastSeq (), params.map { case (n, pt) => (n, pt) }, normalizedBody), {
49- var ir = body
50- ir = Subst (
51- ir,
52- BindingEnv (params
53- .zipWithIndex
54- .foldLeft(Env .empty[IR ]) { case (e, ((n, t), i)) => e.bind(n, In (i, t)) }),
98+ CodeCacheKey (aggSigs.getOrElse(Array .empty).toFastSeq, params, normalizedBody), {
99+ var ir = Subst (
100+ body,
101+ BindingEnv (Env .fromSeq(params.zipWithIndex.map { case ((n, t), i) => n -> In (i, t) })),
55102 )
56103 ir = LoweringPipeline .compileLowerer(optimize)(ctx, ir).asInstanceOf [IR ].noSharing(ctx)
57-
58104 TypeCheck (ctx, ir)
59105
60106 val fb = EmitFunctionBuilder [F ](
61107 ctx,
62- " Compiled" ,
63- CodeParamType (typeInfo[Region ]) +: params.map { case (_, pt) =>
64- pt
65- },
108+ N .value,
109+ CodeParamType (typeInfo[Region ]) +: params.map(_._2),
66110 CodeParamType (SingleCodeType .typeInfoFromType(ir.typ)),
67111 Some (" Emit.scala" ),
68112 )
@@ -83,65 +127,10 @@ object Compile {
83127 )
84128
85129 val emitContext = EmitContext .analyze(ctx, ir)
86- val rt = Emit (emitContext, ir, fb, expectedCodeReturnType, params.length)
130+ val rt = Emit (emitContext, ir, fb, expectedCodeReturnType, params.length, aggSigs )
87131 CompiledFunction (rt, fb.resultWithIndex(print))
88132 },
89- ).asInstanceOf [CompiledFunction [F ]].tuple
90- }
91- }
92-
93- object CompileWithAggregators {
94- def apply [F : TypeInfo ](
95- ctx : ExecuteContext ,
96- aggSigs : Array [AggStateSig ],
97- params : IndexedSeq [(Name , EmitParamType )],
98- expectedCodeParamTypes : IndexedSeq [TypeInfo [_]],
99- expectedCodeReturnType : TypeInfo [_],
100- body : IR ,
101- optimize : Boolean = true ,
102- ): (
103- Option [SingleCodeType ],
104- (HailClassLoader , FS , HailTaskContext , Region ) => (F with FunctionWithAggRegion ),
105- ) =
106- ctx.time {
107- val normalizedBody = NormalizeNames (ctx, body, allowFreeVariables = true )
108- ctx.CodeCache .getOrElseUpdate(
109- CodeCacheKey (aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody), {
110- var ir = body
111- ir = Subst (
112- ir,
113- BindingEnv (params
114- .zipWithIndex
115- .foldLeft(Env .empty[IR ]) { case (e, ((n, t), i)) => e.bind(n, In (i, t)) }),
116- )
117- ir =
118- LoweringPipeline .compileLowerer(optimize).apply(ctx, ir).asInstanceOf [IR ].noSharing(ctx)
119-
120- TypeCheck (
121- ctx,
122- ir,
123- BindingEnv (Env .fromSeq[Type ](params.map { case (name, t) => name -> t.virtualType })),
124- )
125-
126- val fb = EmitFunctionBuilder [F with FunctionWithAggRegion ](
127- ctx,
128- " CompiledWithAggs" ,
129- CodeParamType (typeInfo[Region ]) +: params.map { case (_, pt) => pt },
130- SingleCodeType .typeInfoFromType(ir.typ),
131- Some (" Emit.scala" ),
132- )
133-
134- /* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${
135- * x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c)
136- * } }
137- *
138- * visit(ir) } */
139-
140- val emitContext = EmitContext .analyze(ctx, ir)
141- val rt = Emit (emitContext, ir, fb, expectedCodeReturnType, params.length, Some (aggSigs))
142- CompiledFunction (rt, fb.resultWithIndex())
143- },
144- ).asInstanceOf [CompiledFunction [F with FunctionWithAggRegion ]].tuple
133+ ).asInstanceOf [CompiledFunction [F with Mixin ]].tuple
145134 }
146135}
147136
0 commit comments