
Commit 2d0eb42

[query] Lowering + Optimisation with implicit timing context
1 parent d2968b0 commit 2d0eb42

22 files changed, +705 -702 lines changed
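
The diffs below all apply the same pattern: each lowering or optimisation pass now wraps its body in ctx.time { ... }, so its running time is recorded against the query's timing context without threading an explicit timer argument through every call site. The sketch below shows one way such an implicit timing context could look; it is an illustration only, not Hail's actual ExecuteContext, and it leans on the sourcecode library (already used elsewhere in this diff) to capture the enclosing pass name implicitly.

import scala.collection.mutable

// Illustrative timing context: `time` measures a pass body and attributes the
// elapsed nanoseconds to the enclosing method, captured implicitly by sourcecode.
class TimingContext {
  private val totals = mutable.Map.empty[String, Long].withDefaultValue(0L)

  def time[A](body: => A)(implicit enclosing: sourcecode.Enclosing): A = {
    val start = System.nanoTime()
    try body
    finally totals(enclosing.value) += System.nanoTime() - start
  }

  def report(): Unit =
    totals.foreach { case (pass, nanos) => println(f"$pass%-60s ${nanos / 1e6}%10.2f ms") }
}

With that shape, a pass written as def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR = ctx.time { ... } times itself, which is exactly the change made to FoldConstants, ExtractIntervalFilters and ForwardLets below.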

hail/src/main/scala/is/hail/backend/local/LocalBackend.scala

Lines changed: 1 addition & 1 deletion
@@ -130,7 +130,7 @@ object LocalBackend extends Backend {
     Validate(ir)
     val queryID = Backend.nextID()
     log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}")
-    ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
+    ctx.irMetadata.semhash = SemanticHash(ctx, ir)
     val res = _jvmLowerAndExecute(ctx, ir)
     log.info(s"finished execution of query $queryID")
     res

hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala

Lines changed: 1 addition & 1 deletion
@@ -242,7 +242,7 @@ class ServiceBackend(
     Validate(ir)
     val queryID = Backend.nextID()
     log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}")
-    ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
+    ctx.irMetadata.semhash = SemanticHash(ctx, ir)
     val res = _jvmLowerAndExecute(ctx, ir)
     log.info(s"finished execution of query $queryID")
     res

hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala

Lines changed: 1 addition & 1 deletion
@@ -424,7 +424,7 @@ class SparkBackend(val sc: SparkContext) extends Backend {
     ctx.time {
       TypeCheck(ctx, ir)
       Validate(ir)
-      ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
+      ctx.irMetadata.semhash = SemanticHash(ctx, ir)
       try {
         val lowerTable = ctx.flags.get("lower") != null
         val lowerBM = ctx.flags.get("lower_bm") != null

hail/src/main/scala/is/hail/expr/ir/BaseIR.scala

Lines changed: 5 additions & 5 deletions
@@ -33,11 +33,11 @@ abstract class BaseIR {
   // New sentinel values can be obtained by `nextFlag` on `IRMetadata`.
   var mark: Int = 0
 
-  def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean =
-    /* FIXME: rewrite to not rebuild the irs, by maintaining an env mapping left names to right
-     * names */
-    NormalizeNames(ctx, this, allowFreeVariables = true) ==
-      NormalizeNames(ctx, other, allowFreeVariables = true)
+  def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean = {
+    // FIXME: rewrite to not rebuild the irs by maintaining an env mapping left to right names
+    val normalize: (ExecuteContext, BaseIR) => BaseIR = NormalizeNames(allowFreeVariables = true)
+    normalize(ctx, this) == normalize(ctx, other)
+  }
 
   def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): BaseIR = {
     val newChildren = childrenSeq.view.zipWithIndex.map(f.tupled).toArray
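
Aside on the isAlphaEquiv rewrite above: partially applying NormalizeNames to its options yields a plain (ExecuteContext, BaseIR) => BaseIR value, so one configured pass can be applied to both IRs being compared (the same curried form appears in Compile.scala below). A minimal sketch of that configure-then-apply shape, with a hypothetical RenamePass standing in for the real pass; the assumption that the configured pass is applied to the context and IR in a second step is inferred from this diff, not from Hail's source.

// Hypothetical stand-in for a configurable pass; the name and body are illustrative.
final case class RenamePass(allowFreeVariables: Boolean)
    extends ((ExecuteContext, BaseIR) => BaseIR) {
  def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR =
    ctx.time {
      ir // a real pass would rebuild the IR with fresh binder names here
    }
}

// Usage mirroring the new isAlphaEquiv body: configure once, apply to both sides.
def alphaEquiv(ctx: ExecuteContext, left: BaseIR, right: BaseIR): Boolean = {
  val normalize: (ExecuteContext, BaseIR) => BaseIR = RenamePass(allowFreeVariables = true)
  normalize(ctx, left) == normalize(ctx, right)
}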

hail/src/main/scala/is/hail/expr/ir/Compile.scala

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ object compile {
     N: sourcecode.Name,
   ): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F with Mixin) =
     ctx.time {
-      val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true)
+      val normalizedBody = NormalizeNames(allowFreeVariables = true)(ctx, body)
       ctx.CodeCache.getOrElseUpdate(
         CodeCacheKey(aggSigs.getOrElse(Array.empty).toFastSeq, params, normalizedBody), {
           var ir = Subst(

hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala

Lines changed: 37 additions & 36 deletions
@@ -24,52 +24,53 @@ object ExtractIntervalFilters {
 
   val MAX_LITERAL_SIZE = 4096
 
-  def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR = {
-    MapIR.mapBaseIR(
-      ir0,
-      (ir: BaseIR) => {
-        (
-          ir match {
-            case TableFilter(child, pred) =>
-              extractPartitionFilters(
-                ctx,
-                pred,
-                Ref(TableIR.rowName, child.typ.rowType),
-                child.typ.key,
-              )
-                .map { case (newCond, intervals) =>
+  def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR =
+    ctx.time {
+      MapIR.mapBaseIR(
+        ir0,
+        ir =>
+          (
+            ir match {
+              case TableFilter(child, pred) =>
+                extractPartitionFilters(
+                  ctx,
+                  pred,
+                  Ref(TableIR.rowName, child.typ.rowType),
+                  child.typ.key,
+                )
+                  .map { case (newCond, intervals) =>
+                    log.info(
+                      s"generated TableFilterIntervals node with ${intervals.length} intervals:\n " +
+                        s"Intervals: ${intervals.mkString(", ")}\n " +
+                        s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
+                    )
+                    TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond)
+                  }
+              case MatrixFilterRows(child, pred) =>
+                extractPartitionFilters(
+                  ctx,
+                  pred,
+                  Ref(MatrixIR.rowName, child.typ.rowType),
+                  child.typ.rowKey,
+                ).map { case (newCond, intervals) =>
                   log.info(
-                    s"generated TableFilterIntervals node with ${intervals.length} intervals:\n " +
+                    s"generated MatrixFilterIntervals node with ${intervals.length} intervals:\n " +
                       s"Intervals: ${intervals.mkString(", ")}\n " +
                       s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
                   )
-                  TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond)
+                  MatrixFilterRows(MatrixFilterIntervals(child, intervals, keep = true), newCond)
                 }
-            case MatrixFilterRows(child, pred) => extractPartitionFilters(
-                ctx,
-                pred,
-                Ref(MatrixIR.rowName, child.typ.rowType),
-                child.typ.rowKey,
-              ).map { case (newCond, intervals) =>
-                log.info(
-                  s"generated MatrixFilterIntervals node with ${intervals.length} intervals:\n " +
-                    s"Intervals: ${intervals.mkString(", ")}\n " +
-                    s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
-                )
-                MatrixFilterRows(MatrixFilterIntervals(child, intervals, keep = true), newCond)
-              }
 
-            case _ => None
-          }
-        ).getOrElse(ir)
-      },
-    )
-  }
+              case _ => None
+            }
+          ).getOrElse(ir),
+      )
+    }
 
   def extractPartitionFilters(ctx: ExecuteContext, cond: IR, ref: Ref, key: IndexedSeq[String])
     : Option[(IR, IndexedSeq[Interval])] = {
     if (key.isEmpty) None
-    else {
+    else ctx.time {
       val extract =
         new ExtractIntervalFilters(ctx, ref.typ.asInstanceOf[TStruct].typeAfterSelectNames(key))
       val trueSet = extract.analyze(cond, ref.name)

hail/src/main/scala/is/hail/expr/ir/FoldConstants.scala

Lines changed: 3 additions & 1 deletion
@@ -6,7 +6,9 @@ import is.hail.utils.HailException
 
 object FoldConstants {
   def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR =
-    ctx.r.pool.scopedRegion(region => ctx.local(r = region)(foldConstants(_, ir)))
+    ctx.time {
+      ctx.r.pool.scopedRegion(r => ctx.local(r = r)(foldConstants(_, ir)))
+    }
 
   private def foldConstants(ctx: ExecuteContext, ir: BaseIR): BaseIR =
     RewriteBottomUp(

hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala

Lines changed: 53 additions & 42 deletions
@@ -2,20 +2,20 @@ package is.hail.expr.ir
 
 import is.hail.backend.ExecuteContext
 import is.hail.types.virtual.TVoid
-import is.hail.utils.BoxedArrayBuilder
+import is.hail.utils.{fatal, BoxedArrayBuilder}
 
 import scala.collection.Set
+import scala.util.control.NonFatal
 
 object ForwardLets {
-  def apply[T <: BaseIR](ctx: ExecuteContext)(ir0: T): T = {
-    val ir1 = NormalizeNames(ctx, ir0, allowFreeVariables = true)
-    val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
-    val nestingDepth = NestingDepth(ir1)
-
-    def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR = {
+  def apply[T <: BaseIR](ctx: ExecuteContext, ir0: T): T =
+    ctx.time {
+      val ir1 = NormalizeNames(allowFreeVariables = true)(ctx, ir0)
+      val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
+      val nestingDepth = NestingDepth(ctx, ir1)
 
     def shouldForward(value: IR, refs: Set[RefEquality[BaseRef]], base: Block, scope: Int)
-      : Boolean = {
+      : Boolean =
      IsPure(value) && (
        value.isInstanceOf[Ref] ||
        value.isInstanceOf[In] ||
@@ -27,45 +27,56 @@ object ForwardLets {
          !ContainsAgg(value)) &&
        !ContainsAggIntermediate(value)
      )
-    }
 
-    ir match {
-      case l: Block =>
-        val keep = new BoxedArrayBuilder[Binding]
-        val refs = uses(l)
-        val newEnv = l.bindings.foldLeft(env) {
-          case (env, Binding(name, value, scope)) =>
-            val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR]
-            if (
-              rewriteValue.typ != TVoid
-              && shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope)
-            ) {
-              env.bindInScope(name, rewriteValue, scope)
-            } else {
-              keep += Binding(name, rewriteValue, scope)
-              env
+      def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR =
+        ir match {
+          case l: Block =>
+            val keep = new BoxedArrayBuilder[Binding]
+            val refs = uses(l)
+            val newEnv = l.bindings.foldLeft(env) {
+              case (env, Binding(name, value, scope)) =>
+                val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR]
+                if (
+                  rewriteValue.typ != TVoid
+                  && shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope)
+                ) {
+                  env.bindInScope(name, rewriteValue, scope)
+                } else {
+                  keep += Binding(name, rewriteValue, scope)
+                  env
+                }
+            }
+
+            val newBody = rewrite(l.body, newEnv).asInstanceOf[IR]
+            if (keep.isEmpty) newBody
+            else Block(keep.result(), newBody)
+
+          case x @ Ref(name, _) =>
+            env.eval
+              .lookupOption(name)
+              .map { forwarded =>
+                if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy()
+                else forwarded
            }
-        }
+              .getOrElse(x)
+          case _ =>
+            ir.mapChildrenWithIndex((ir1, i) =>
+              rewrite(ir1, env.extend(Bindings.get(ir, i).dropBindings))
+            )
+        }
 
-    val newBody = rewrite(l.body, newEnv).asInstanceOf[IR]
-    if (keep.isEmpty) newBody
-    else Block(keep.result(), newBody)
+      val ir = rewrite(ir1, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty)))
 
-      case x @ Ref(name, _) =>
-        env.eval
-          .lookupOption(name)
-          .map { forwarded =>
-            if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy()
-            else forwarded
-          }
-          .getOrElse(x)
-      case _ =>
-        ir.mapChildrenWithIndex((ir1, i) =>
-          rewrite(ir1, env.extend(Bindings.get(ir, i).dropBindings))
+      try
+        TypeCheck(ctx, ir)
+      catch {
+        case NonFatal(e) =>
+          fatal(
+            s"bad ir from ForwardLets, started as\n${Pretty(ctx, ir0, preserveNames = true)}",
+            e,
          )
      }
-  }
 
-    rewrite(ir1, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty))).asInstanceOf[T]
-  }
+      ir.asInstanceOf[T]
+    }
 }
