Skip to content

Commit 1abe76b

Browse files
committed
[query] Lowering + Optimisation with implicit timing context
1 parent b9629d0 commit 1abe76b

28 files changed

+2171
-1969
lines changed

hail/hail/src/is/hail/backend/local/LocalBackend.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ object LocalBackend extends Backend {
144144
Validate(ir)
145145
val queryID = Backend.nextID()
146146
log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}")
147-
ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
147+
ctx.irMetadata.semhash = SemanticHash(ctx, ir)
148148
val res = _jvmLowerAndExecute(ctx, ir)
149149
log.info(s"finished execution of query $queryID")
150150
res

hail/hail/src/is/hail/backend/service/ServiceBackend.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ class ServiceBackend(
285285
Validate(ir)
286286
val queryID = Backend.nextID()
287287
log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}")
288-
ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
288+
ctx.irMetadata.semhash = SemanticHash(ctx, ir)
289289
val res = _jvmLowerAndExecute(ctx, ir)
290290
log.info(s"finished execution of query $queryID")
291291
res

hail/hail/src/is/hail/backend/spark/SparkBackend.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ class SparkBackend(val sc: SparkContext) extends Backend {
447447
ctx.time {
448448
TypeCheck(ctx, ir)
449449
Validate(ir)
450-
ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
450+
ctx.irMetadata.semhash = SemanticHash(ctx, ir)
451451
try {
452452
val lowerTable = ctx.flags.get("lower") != null
453453
val lowerBM = ctx.flags.get("lower_bm") != null

hail/hail/src/is/hail/expr/ir/BaseIR.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ abstract class BaseIR {
3333
// New sentinel values can be obtained by `nextFlag` on `IRMetadata`.
3434
var mark: Int = 0
3535

36-
def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean =
37-
/* FIXME: rewrite to not rebuild the irs, by maintaining an env mapping left names to right
38-
* names */
39-
NormalizeNames(ctx, this, allowFreeVariables = true) ==
40-
NormalizeNames(ctx, other, allowFreeVariables = true)
36+
def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean = {
37+
// FIXME: rewrite to not rebuild the irs by maintaining an env mapping left to right names
38+
val normalize: (ExecuteContext, BaseIR) => BaseIR = NormalizeNames(allowFreeVariables = true)
39+
normalize(ctx, this) == normalize(ctx, other)
40+
}
4141

4242
def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): BaseIR = {
4343
val newChildren = childrenSeq.view.zipWithIndex.map(f.tupled).toArray

hail/hail/src/is/hail/expr/ir/Compilable.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ object InterpretableButNotCompilable {
2020
case _: MatrixToValueApply => true
2121
case _: BlockMatrixToValueApply => true
2222
case _: BlockMatrixCollect => true
23-
case _: BlockMatrixToTableApply => true
2423
case _ => false
2524
}
2625
}
@@ -44,7 +43,6 @@ object Compilable {
4443
case _: TableToValueApply => false
4544
case _: MatrixToValueApply => false
4645
case _: BlockMatrixToValueApply => false
47-
case _: BlockMatrixToTableApply => false
4846
case _: RelationalRef => false
4947
case _: RelationalLet => false
5048
case _ => true

hail/hail/src/is/hail/expr/ir/Compile.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ object compile {
9494
N: sourcecode.Name,
9595
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F with Mixin) =
9696
ctx.time {
97-
val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true)
97+
val normalizedBody = NormalizeNames(allowFreeVariables = true)(ctx, body)
9898
ctx.CodeCache.getOrElseUpdate(
9999
CodeCacheKey(aggSigs.getOrElse(Array.empty).toFastSeq, params, normalizedBody), {
100100
var ir = Subst(

hail/hail/src/is/hail/expr/ir/ExtractIntervalFilters.scala

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,52 +25,54 @@ object ExtractIntervalFilters {
2525

2626
val MAX_LITERAL_SIZE = 4096
2727

28-
def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR = {
29-
MapIR.mapBaseIR(
30-
ir0,
31-
(ir: BaseIR) => {
32-
(
33-
ir match {
34-
case TableFilter(child, pred) =>
35-
extractPartitionFilters(
36-
ctx,
37-
pred,
38-
Ref(TableIR.rowName, child.typ.rowType),
39-
child.typ.key,
40-
)
41-
.map { case (newCond, intervals) =>
42-
log.info(
43-
s"generated TableFilterIntervals node with ${intervals.length} intervals:\n " +
44-
s"Intervals: ${intervals.mkString(", ")}\n " +
45-
s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
46-
)
47-
TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond)
48-
}
49-
case MatrixFilterRows(child, pred) => extractPartitionFilters(
50-
ctx,
51-
pred,
52-
Ref(MatrixIR.rowName, child.typ.rowType),
53-
child.typ.rowKey,
54-
).map { case (newCond, intervals) =>
28+
def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR =
29+
ctx.time {
30+
MapIR.mapBaseIR(
31+
ir0,
32+
{
33+
case ir @ TableFilter(child, pred) =>
34+
extractPartitionFilters(
35+
ctx,
36+
pred,
37+
Ref(TableIR.rowName, child.typ.rowType),
38+
child.typ.key,
39+
)
40+
.map { case (newCond, intervals) =>
41+
log.info(
42+
s"generated TableFilterIntervals node with ${intervals.length} intervals:\n " +
43+
s"Intervals: ${intervals.mkString(", ")}\n " +
44+
s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
45+
)
46+
TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond)
47+
}
48+
.getOrElse(ir)
49+
50+
case ir @ MatrixFilterRows(child, pred) =>
51+
extractPartitionFilters(
52+
ctx,
53+
pred,
54+
Ref(MatrixIR.rowName, child.typ.rowType),
55+
child.typ.rowKey,
56+
)
57+
.map { case (newCond, intervals) =>
5558
log.info(
5659
s"generated MatrixFilterIntervals node with ${intervals.length} intervals:\n " +
5760
s"Intervals: ${intervals.mkString(", ")}\n " +
5861
s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
5962
)
6063
MatrixFilterRows(MatrixFilterIntervals(child, intervals, keep = true), newCond)
6164
}
65+
.getOrElse(ir)
6266

63-
case _ => None
64-
}
65-
).getOrElse(ir)
66-
},
67-
)
68-
}
67+
case ir => ir
68+
},
69+
)
70+
}
6971

7072
def extractPartitionFilters(ctx: ExecuteContext, cond: IR, ref: Ref, key: IndexedSeq[String])
71-
: Option[(IR, IndexedSeq[Interval])] = {
73+
: Option[(IR, IndexedSeq[Interval])] =
7274
if (key.isEmpty) None
73-
else {
75+
else ctx.time {
7476
val extract =
7577
new ExtractIntervalFilters(ctx, ref.typ.asInstanceOf[TStruct].typeAfterSelectNames(key))
7678
val trueSet = extract.analyze(cond, ref.name)
@@ -82,7 +84,6 @@ object ExtractIntervalFilters {
8284
Some((extract.rewrite(cond, rw), trueSet))
8385
}
8486
}
85-
}
8687

8788
def liftPosIntervalsToLocus(pos: IndexedSeq[Interval], rg: ReferenceGenome, ctx: ExecuteContext)
8889
: IndexedSeq[Interval] = {

hail/hail/src/is/hail/expr/ir/FoldConstants.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ import is.hail.utils.HailException
77

88
object FoldConstants {
99
def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR =
10-
ctx.r.pool.scopedRegion(region => ctx.local(r = region)(foldConstants(_, ir)))
10+
ctx.time {
11+
ctx.r.pool.scopedRegion(r => ctx.local(r = r)(foldConstants(_, ir)))
12+
}
1113

1214
private def foldConstants(ctx: ExecuteContext, ir: BaseIR): BaseIR =
1315
RewriteBottomUp(

hail/hail/src/is/hail/expr/ir/ForwardLets.scala

Lines changed: 42 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,14 @@ import is.hail.utils.BoxedArrayBuilder
88
import scala.collection.Set
99

1010
object ForwardLets {
11-
def apply[T <: BaseIR](ctx: ExecuteContext)(ir0: T): T = {
12-
val ir1 = NormalizeNames(ctx, ir0, allowFreeVariables = true)
13-
val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
14-
val nestingDepth = NestingDepth(ir1)
15-
16-
def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR = {
11+
def apply[T <: BaseIR](ctx: ExecuteContext, ir0: T): T =
12+
ctx.time {
13+
val ir1 = NormalizeNames(allowFreeVariables = true)(ctx, ir0)
14+
val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
15+
val nestingDepth = NestingDepth(ctx, ir1)
1716

1817
def shouldForward(value: IR, refs: Set[RefEquality[BaseRef]], base: Block, scope: Int)
19-
: Boolean = {
18+
: Boolean =
2019
IsPure(value) && (
2120
value.isInstanceOf[Ref] ||
2221
value.isInstanceOf[In] ||
@@ -28,45 +27,44 @@ object ForwardLets {
2827
!ContainsAgg(value)) &&
2928
!ContainsAggIntermediate(value)
3029
)
31-
}
3230

33-
ir match {
34-
case l: Block =>
35-
val keep = new BoxedArrayBuilder[Binding]
36-
val refs = uses(l)
37-
val newEnv = l.bindings.foldLeft(env) {
38-
case (env, Binding(name, value, scope)) =>
39-
val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR]
40-
if (
41-
rewriteValue.typ != TVoid
42-
&& shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope)
43-
) {
44-
env.bindInScope(name, rewriteValue, scope)
45-
} else {
46-
keep += Binding(name, rewriteValue, scope)
47-
env
48-
}
49-
}
31+
def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR =
32+
ir match {
33+
case l: Block =>
34+
val keep = new BoxedArrayBuilder[Binding]
35+
val refs = uses(l)
36+
val newEnv = l.bindings.foldLeft(env) {
37+
case (env, Binding(name, value, scope)) =>
38+
val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR]
39+
if (
40+
rewriteValue.typ != TVoid
41+
&& shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope)
42+
) {
43+
env.bindInScope(name, rewriteValue, scope)
44+
} else {
45+
keep += Binding(name, rewriteValue, scope)
46+
env
47+
}
48+
}
5049

51-
val newBody = rewrite(l.body, newEnv).asInstanceOf[IR]
52-
if (keep.isEmpty) newBody
53-
else Block(keep.result(), newBody)
50+
val newBody = rewrite(l.body, newEnv).asInstanceOf[IR]
51+
if (keep.isEmpty) newBody
52+
else Block(keep.result(), newBody)
5453

55-
case x @ Ref(name, _) =>
56-
env.eval
57-
.lookupOption(name)
58-
.map { forwarded =>
59-
if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy()
60-
else forwarded
61-
}
62-
.getOrElse(x)
63-
case _ =>
64-
ir.mapChildrenWithIndex((ir1, i) =>
65-
rewrite(ir1, env.extend(Bindings.get(ir, i).dropBindings))
66-
)
67-
}
68-
}
54+
case x @ Ref(name, _) =>
55+
env.eval
56+
.lookupOption(name)
57+
.map { forwarded =>
58+
if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy()
59+
else forwarded
60+
}
61+
.getOrElse(x)
62+
case _ =>
63+
ir.mapChildrenWithIndex((ir1, i) =>
64+
rewrite(ir1, env.extend(Bindings.get(ir, i).dropBindings))
65+
)
66+
}
6967

70-
rewrite(ir1, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty))).asInstanceOf[T]
71-
}
68+
rewrite(ir1, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty))).asInstanceOf[T]
69+
}
7270
}

0 commit comments

Comments
 (0)