Skip to content

Commit 9615ece

Browse files
committed
[query] Move LoweredTableReaderCoercer into ExecuteContext
1 parent 0ce4a09 commit 9615ece

File tree

9 files changed

+383
-380
lines changed

9 files changed

+383
-380
lines changed

hail/src/main/scala/is/hail/backend/Backend.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,6 @@ abstract class Backend extends Closeable {
8888
def asSpark(implicit E: Enclosing): SparkBackend =
8989
fatal(s"${getClass.getSimpleName}: ${E.value} requires SparkBackend")
9090

91-
def shouldCacheQueryInfo: Boolean = true
92-
9391
def lowerDistributedSort(
9492
ctx: ExecuteContext,
9593
stage: TableStage,

hail/src/main/scala/is/hail/backend/ExecuteContext.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import is.hail.annotations.{Region, RegionPool}
55
import is.hail.asm4s.HailClassLoader
66
import is.hail.backend.local.LocalTaskContext
77
import is.hail.expr.ir.{BaseIR, CodeCacheKey, CompiledFunction}
8+
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
89
import is.hail.expr.ir.lowering.IrMetadata
910
import is.hail.io.fs.FS
1011
import is.hail.linalg.BlockMatrix
@@ -74,6 +75,7 @@ object ExecuteContext {
7475
blockMatrixCache: mutable.Map[String, BlockMatrix],
7576
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
7677
irCache: mutable.Map[Int, BaseIR],
78+
coercerCache: mutable.Map[Any, LoweredTableReaderCoercer],
7779
)(
7880
f: ExecuteContext => T
7981
): T = {
@@ -95,6 +97,7 @@ object ExecuteContext {
9597
blockMatrixCache,
9698
codeCache,
9799
irCache,
100+
coercerCache,
98101
))(f(_))
99102
}
100103
}
@@ -127,6 +130,7 @@ class ExecuteContext(
127130
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
128131
val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
129132
val IrCache: mutable.Map[Int, BaseIR],
133+
val CoercerCache: mutable.Map[Any, LoweredTableReaderCoercer],
130134
) extends Closeable {
131135

132136
val rngNonce: Long =
@@ -199,6 +203,7 @@ class ExecuteContext(
199203
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
200204
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
201205
irCache: mutable.Map[Int, BaseIR] = this.IrCache,
206+
coercerCache: mutable.Map[Any, LoweredTableReaderCoercer] = this.CoercerCache,
202207
)(
203208
f: ExecuteContext => A
204209
): A =
@@ -218,5 +223,6 @@ class ExecuteContext(
218223
blockMatrixCache,
219224
codeCache,
220225
irCache,
226+
coercerCache,
221227
))(f)
222228
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package is.hail.backend
2+
3+
import scala.collection.mutable
4+
5+
package object caching {
6+
private[this] object NoCachingInstance extends mutable.AbstractMap[Any, Any] {
7+
override def +=(kv: (Any, Any)): NoCachingInstance.this.type = this
8+
override def -=(key: Any): NoCachingInstance.this.type = this
9+
override def get(key: Any): Option[Any] = None
10+
override def iterator: Iterator[(Any, Any)] = Iterator.empty
11+
override def getOrElseUpdate(key: Any, op: => Any): Any = op
12+
}
13+
14+
def NoCaching[K, V]: mutable.Map[K, V] =
15+
NoCachingInstance.asInstanceOf[mutable.Map[K, V]]
16+
}

hail/src/main/scala/is/hail/backend/local/LocalBackend.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import is.hail.backend._
77
import is.hail.backend.py4j.Py4JBackendExtensions
88
import is.hail.expr.Validate
99
import is.hail.expr.ir._
10+
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
1011
import is.hail.expr.ir.analyses.SemanticHash
1112
import is.hail.expr.ir.compile.Compile
1213
import is.hail.expr.ir.lowering._
@@ -81,6 +82,7 @@ class LocalBackend(
8182
private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
8283
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
8384
private[this] val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()
85+
private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)
8486

8587
// flags can be set after construction from python
8688
def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
@@ -106,6 +108,7 @@ class LocalBackend(
106108
ImmutableMap.empty,
107109
codeCache,
108110
persistedIR,
111+
coercerCache,
109112
)(f)
110113
}
111114

hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ package is.hail.backend.py4j
33
import is.hail.HailFeatureFlags
44
import is.hail.backend.{Backend, ExecuteContext, NonOwningTempFileManager, TempFileManager}
55
import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex}
6-
import is.hail.expr.ir.{BaseIR, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IRParser, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue}
6+
import is.hail.expr.ir.{
7+
BaseIR, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IRParser, Interpret, MatrixIR,
8+
MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue,
9+
}
710
import is.hail.expr.ir.IRParser.parseType
811
import is.hail.expr.ir.functions.IRFunctionRegistry
912
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}

hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags}
44
import is.hail.annotations._
55
import is.hail.asm4s._
66
import is.hail.backend._
7+
import is.hail.backend.caching.NoCaching
78
import is.hail.backend.service.ServiceBackend.MaxAvailableGcsConnections
89
import is.hail.expr.Validate
910
import is.hail.expr.ir.{
@@ -63,8 +64,6 @@ class ServiceBackend(
6364
private[this] var stageCount = 0
6465
private[this] val executor = Executors.newFixedThreadPool(MaxAvailableGcsConnections)
6566

66-
override def shouldCacheQueryInfo: Boolean = false
67-
6867
def defaultParallelism: Int = 4
6968

7069
def broadcast[T: ClassTag](_value: T): BroadcastValue[T] = {
@@ -316,9 +315,10 @@ class ServiceBackend(
316315
),
317316
new IrMetadata(),
318317
references,
319-
ImmutableMap.empty,
320-
mutable.Map.empty,
321-
ImmutableMap.empty,
318+
NoCaching,
319+
NoCaching,
320+
NoCaching,
321+
NoCaching,
322322
)(f)
323323
}
324324
}

hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@ import is.hail.{HailContext, HailFeatureFlags}
44
import is.hail.annotations._
55
import is.hail.asm4s._
66
import is.hail.backend._
7-
import is.hail.backend.caching.BlockMatrixCache
7+
import is.hail.backend.caching.{BlockMatrixCache, NoCaching}
88
import is.hail.backend.py4j.Py4JBackendExtensions
99
import is.hail.expr.Validate
1010
import is.hail.expr.ir._
11+
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
1112
import is.hail.expr.ir.analyses.SemanticHash
1213
import is.hail.expr.ir.compile.Compile
1314
import is.hail.expr.ir.lowering._
@@ -343,6 +344,7 @@ class SparkBackend(
343344
private[this] val bmCache = new BlockMatrixCache()
344345
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
345346
private[this] val persistedIr = mutable.Map.empty[Int, BaseIR]
347+
private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)
346348

347349
def createExecuteContextForTests(
348350
timer: ExecutionTimer,
@@ -365,8 +367,9 @@ class SparkBackend(
365367
new IrMetadata(),
366368
references,
367369
ImmutableMap.empty,
368-
mutable.Map.empty,
370+
NoCaching,
369371
ImmutableMap.empty,
372+
NoCaching,
370373
)
371374

372375
override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) =
@@ -393,6 +396,7 @@ class SparkBackend(
393396
bmCache,
394397
codeCache,
395398
persistedIr,
399+
coercerCache,
396400
)(f)
397401
}
398402

hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package is.hail.expr.ir
33
import is.hail.annotations.Region
44
import is.hail.asm4s._
55
import is.hail.backend.ExecuteContext
6+
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
67
import is.hail.expr.ir.functions.UtilFunctions
78
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
89
import is.hail.expr.ir.streams.StreamProducer
@@ -143,16 +144,6 @@ class PartitionIteratorLongReader(
143144
)
144145
}
145146

146-
abstract class LoweredTableReaderCoercer {
147-
def coerce(
148-
ctx: ExecuteContext,
149-
globals: IR,
150-
contextType: Type,
151-
contexts: IndexedSeq[Any],
152-
body: IR => IR,
153-
): TableStage
154-
}
155-
156147
class GenericTableValue(
157148
val fullTableType: TableType,
158149
val uidFieldName: String,
@@ -168,12 +159,11 @@ class GenericTableValue(
168159
assert(contextType.hasField("partitionIndex"))
169160
assert(contextType.fieldType("partitionIndex") == TInt32)
170161

171-
private var ltrCoercer: LoweredTableReaderCoercer = _
172-
173162
private def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any)
174-
: LoweredTableReaderCoercer = {
175-
if (ltrCoercer == null) {
176-
ltrCoercer = LoweredTableReader.makeCoercer(
163+
: LoweredTableReaderCoercer =
164+
ctx.CoercerCache.getOrElseUpdate(
165+
(1, contextType, fullTableType.key, cacheKey),
166+
LoweredTableReader.makeCoercer(
177167
ctx,
178168
fullTableType.key,
179169
1,
@@ -184,11 +174,8 @@ class GenericTableValue(
184174
bodyPType,
185175
body,
186176
context,
187-
cacheKey,
188-
)
189-
}
190-
ltrCoercer
191-
}
177+
),
178+
)
192179

193180
def toTableStage(ctx: ExecuteContext, requestedType: TableType, context: String, cacheKey: Any)
194181
: TableStage = {
@@ -217,11 +204,13 @@ class GenericTableValue(
217204
val contextsIR = ToStream(Literal(TArray(contextType), contexts))
218205
TableStage(globalsIR, p, TableStageDependency.none, contextsIR, requestedBody)
219206
} else {
220-
getLTVCoercer(ctx, context, cacheKey).coerce(
207+
getLTVCoercer(ctx, context, cacheKey)(
221208
ctx,
222209
globalsIR,
223-
contextType, contexts,
224-
requestedBody)
210+
contextType,
211+
contexts,
212+
requestedBody,
213+
)
225214
}
226215
}
227216
}

0 commit comments

Comments
 (0)