Skip to content

Commit df36f36

Browse files
committed
[query] Move LoweredTableReaderCoercer into ExecuteContext
1 parent 234a058 commit df36f36

File tree

7 files changed

+359
-376
lines changed

7 files changed

+359
-376
lines changed

hail/src/main/scala/is/hail/backend/Backend.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,6 @@ abstract class Backend extends Closeable {
8585
def asSpark(op: String): SparkBackend =
8686
fatal(s"${getClass.getSimpleName}: $op requires SparkBackend")
8787

88-
def shouldCacheQueryInfo: Boolean = true
89-
9088
def lowerDistributedSort(
9189
ctx: ExecuteContext,
9290
stage: TableStage,

hail/src/main/scala/is/hail/backend/ExecuteContext.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import is.hail.annotations.{Region, RegionPool}
55
import is.hail.asm4s.HailClassLoader
66
import is.hail.backend.local.LocalTaskContext
77
import is.hail.expr.ir.{BaseIR, CodeCacheKey, CompiledFunction}
8+
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
89
import is.hail.expr.ir.lowering.IrMetadata
910
import is.hail.io.fs.FS
1011
import is.hail.linalg.BlockMatrix
@@ -76,6 +77,7 @@ object ExecuteContext {
7677
blockMatrixCache: mutable.Map[String, BlockMatrix],
7778
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
7879
irCache: mutable.Map[Int, BaseIR],
80+
coercerCache: mutable.Map[Any, LoweredTableReaderCoercer],
7981
)(
8082
f: ExecuteContext => T
8183
): T = {
@@ -97,6 +99,7 @@ object ExecuteContext {
9799
blockMatrixCache,
98100
codeCache,
99101
irCache,
102+
coercerCache,
100103
))(f(_))
101104
}
102105
}
@@ -129,6 +132,7 @@ class ExecuteContext(
129132
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
130133
val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
131134
val IrCache: mutable.Map[Int, BaseIR],
135+
val CoercerCache: mutable.Map[Any, LoweredTableReaderCoercer],
132136
) extends Closeable {
133137

134138
val rngNonce: Long =
@@ -198,6 +202,7 @@ class ExecuteContext(
198202
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
199203
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
200204
irCache: mutable.Map[Int, BaseIR] = this.IrCache,
205+
coercerCache: mutable.Map[Any, LoweredTableReaderCoercer] = this.CoercerCache,
201206
)(
202207
f: ExecuteContext => A
203208
): A =
@@ -217,5 +222,6 @@ class ExecuteContext(
217222
blockMatrixCache,
218223
codeCache,
219224
irCache,
225+
coercerCache,
220226
))(f)
221227
}

hail/src/main/scala/is/hail/backend/local/LocalBackend.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import is.hail.backend._
77
import is.hail.backend.py4j.Py4JBackendExtensions
88
import is.hail.expr.Validate
99
import is.hail.expr.ir._
10+
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
1011
import is.hail.expr.ir.analyses.SemanticHash
1112
import is.hail.expr.ir.compile.Compile
1213
import is.hail.expr.ir.lowering._
@@ -93,6 +94,7 @@ class LocalBackend(
9394
private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
9495
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
9596
private[this] val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()
97+
private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)
9698

9799
// flags can be set after construction from python
98100
def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
@@ -118,6 +120,7 @@ class LocalBackend(
118120
ImmutableMap.empty,
119121
codeCache,
120122
persistedIR,
123+
coercerCache,
121124
)(f)
122125
}
123126

hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,6 @@ class ServiceBackend(
135135
private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000
136136
private[this] val executor = Executors.newFixedThreadPool(MAX_AVAILABLE_GCS_CONNECTIONS)
137137

138-
override def shouldCacheQueryInfo: Boolean = false
139-
140138
def defaultParallelism: Int = 4
141139

142140
def broadcast[T: ClassTag](_value: T): BroadcastValue[T] = {
@@ -391,7 +389,8 @@ class ServiceBackend(
391389
serviceBackendContext,
392390
new IrMetadata(),
393391
ImmutableMap.empty,
394-
mutable.Map.empty,
392+
ImmutableMap.empty,
393+
ImmutableMap.empty,
395394
ImmutableMap.empty,
396395
)(f)
397396
}

hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import is.hail.backend.caching.BlockMatrixCache
88
import is.hail.backend.py4j.Py4JBackendExtensions
99
import is.hail.expr.Validate
1010
import is.hail.expr.ir._
11+
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
1112
import is.hail.expr.ir.analyses.SemanticHash
1213
import is.hail.expr.ir.compile.Compile
1314
import is.hail.expr.ir.lowering._
@@ -355,6 +356,7 @@ class SparkBackend(
355356
private[this] val bmCache = new BlockMatrixCache()
356357
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
357358
private[this] val persistedIr = mutable.Map.empty[Int, BaseIR]
359+
private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)
358360

359361
def createExecuteContextForTests(
360362
timer: ExecutionTimer,
@@ -378,7 +380,8 @@ class SparkBackend(
378380
},
379381
new IrMetadata(),
380382
ImmutableMap.empty,
381-
mutable.Map.empty,
383+
ImmutableMap.empty,
384+
ImmutableMap.empty,
382385
ImmutableMap.empty,
383386
)
384387

@@ -402,6 +405,7 @@ class SparkBackend(
402405
bmCache,
403406
codeCache,
404407
persistedIr,
408+
coercerCache,
405409
)(f)
406410
}
407411

hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package is.hail.expr.ir
33
import is.hail.annotations.Region
44
import is.hail.asm4s._
55
import is.hail.backend.ExecuteContext
6+
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
67
import is.hail.expr.ir.functions.UtilFunctions
78
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
89
import is.hail.expr.ir.streams.StreamProducer
@@ -143,16 +144,6 @@ class PartitionIteratorLongReader(
143144
)
144145
}
145146

146-
abstract class LoweredTableReaderCoercer {
147-
def coerce(
148-
ctx: ExecuteContext,
149-
globals: IR,
150-
contextType: Type,
151-
contexts: IndexedSeq[Any],
152-
body: IR => IR,
153-
): TableStage
154-
}
155-
156147
class GenericTableValue(
157148
val fullTableType: TableType,
158149
val uidFieldName: String,
@@ -168,12 +159,11 @@ class GenericTableValue(
168159
assert(contextType.hasField("partitionIndex"))
169160
assert(contextType.fieldType("partitionIndex") == TInt32)
170161

171-
private var ltrCoercer: LoweredTableReaderCoercer = _
172-
173162
private def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any)
174-
: LoweredTableReaderCoercer = {
175-
if (ltrCoercer == null) {
176-
ltrCoercer = LoweredTableReader.makeCoercer(
163+
: LoweredTableReaderCoercer =
164+
ctx.CoercerCache.getOrElseUpdate(
165+
(1, contextType, fullTableType.key, cacheKey),
166+
LoweredTableReader.makeCoercer(
177167
ctx,
178168
fullTableType.key,
179169
1,
@@ -184,11 +174,8 @@ class GenericTableValue(
184174
bodyPType,
185175
body,
186176
context,
187-
cacheKey,
188-
)
189-
}
190-
ltrCoercer
191-
}
177+
),
178+
)
192179

193180
def toTableStage(ctx: ExecuteContext, requestedType: TableType, context: String, cacheKey: Any)
194181
: TableStage = {
@@ -217,11 +204,13 @@ class GenericTableValue(
217204
val contextsIR = ToStream(Literal(TArray(contextType), contexts))
218205
TableStage(globalsIR, p, TableStageDependency.none, contextsIR, requestedBody)
219206
} else {
220-
getLTVCoercer(ctx, context, cacheKey).coerce(
207+
getLTVCoercer(ctx, context, cacheKey)(
221208
ctx,
222209
globalsIR,
223-
contextType, contexts,
224-
requestedBody)
210+
contextType,
211+
contexts,
212+
requestedBody,
213+
)
225214
}
226215
}
227216
}

0 commit comments

Comments
 (0)