
Commit 23be511

[query] Remove lookupOrCompileCachedFunction from Backend interface
1 parent 97d43aa commit 23be511

File tree

18 files changed: +156 −167 lines changed

hail/hail/src/is/hail/backend/Backend.scala

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ import is.hail.asm4s._
44
import is.hail.backend.Backend.jsonToBytes
55
import is.hail.backend.spark.SparkBackend
66
import is.hail.expr.ir.{
7-
BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses,
8-
SortField, TableIR, TableReader,
7+
BaseIR, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader,
98
}
109
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
1110
import is.hail.io.{BufferSpec, TypedCodecSpec}
@@ -92,9 +91,6 @@ abstract class Backend extends Closeable {
9291

9392
def shouldCacheQueryInfo: Boolean = true
9493

95-
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
96-
: CompiledFunction[T]
97-
9894
def lowerDistributedSort(
9995
ctx: ExecuteContext,
10096
stage: TableStage,
@@ -193,23 +189,3 @@ abstract class Backend extends Closeable {
193189

194190
def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
195191
}
196-
197-
trait BackendWithCodeCache {
198-
private[this] val codeCache: Cache[CodeCacheKey, CompiledFunction[_]] = new Cache(50)
199-
200-
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
201-
: CompiledFunction[T] = {
202-
codeCache.get(k) match {
203-
case Some(v) => v.asInstanceOf[CompiledFunction[T]]
204-
case None =>
205-
val compiledFunction = f
206-
codeCache += ((k, compiledFunction))
207-
compiledFunction
208-
}
209-
}
210-
}
211-
212-
trait BackendWithNoCodeCache {
213-
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
214-
: CompiledFunction[T] = f
215-
}
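The deleted BackendWithCodeCache trait was a look-aside cache keyed by CodeCacheKey: return the cached CompiledFunction on a hit, otherwise force the by-name compilation and remember the result. With the method removed from Backend, the same behaviour can be expressed directly against the mutable.Map[CodeCacheKey, CompiledFunction[_]] that ExecuteContext now carries (next file). A minimal sketch, assuming a free-standing helper whose name and location are hypothetical and not part of this commit:

    import scala.collection.mutable
    import is.hail.expr.ir.{CodeCacheKey, CompiledFunction}

    // Hypothetical helper, not from this commit: the old look-aside behaviour
    // expressed against a plain mutable.Map.
    def lookupOrCompile[T](
      cache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
      key: CodeCacheKey,
    )(compile: => CompiledFunction[T]): CompiledFunction[T] =
      // getOrElseUpdate evaluates `compile` only on a miss, like the deleted trait's by-name `f`.
      cache.getOrElseUpdate(key, compile).asInstanceOf[CompiledFunction[T]]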

hail/hail/src/is/hail/backend/ExecuteContext.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags}
44
import is.hail.annotations.{Region, RegionPool}
55
import is.hail.asm4s.HailClassLoader
66
import is.hail.backend.local.LocalTaskContext
7+
import is.hail.expr.ir.{CodeCacheKey, CompiledFunction}
78
import is.hail.expr.ir.lowering.IrMetadata
89
import is.hail.io.fs.FS
910
import is.hail.linalg.BlockMatrix
@@ -73,6 +74,7 @@ object ExecuteContext {
7374
backendContext: BackendContext,
7475
irMetadata: IrMetadata,
7576
blockMatrixCache: mutable.Map[String, BlockMatrix],
77+
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
7678
)(
7779
f: ExecuteContext => T
7880
): T = {
@@ -92,6 +94,7 @@ object ExecuteContext {
9294
backendContext,
9395
irMetadata,
9496
blockMatrixCache,
97+
codeCache,
9598
))(f(_))
9699
}
97100
}
@@ -122,6 +125,7 @@ class ExecuteContext(
122125
val backendContext: BackendContext,
123126
val irMetadata: IrMetadata,
124127
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
128+
val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
125129
) extends Closeable {
126130

127131
val rngNonce: Long =
@@ -189,6 +193,7 @@ class ExecuteContext(
189193
backendContext: BackendContext = this.backendContext,
190194
irMetadata: IrMetadata = this.irMetadata,
191195
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
196+
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
192197
)(
193198
f: ExecuteContext => A
194199
): A =
@@ -206,5 +211,6 @@ class ExecuteContext(
206211
backendContext,
207212
irMetadata,
208213
blockMatrixCache,
214+
codeCache,
209215
))(f)
210216
}
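Threading the cache through ExecuteContext means whoever performs compilation can consult ctx.CodeCache directly instead of calling back into the Backend. A sketch of the consumer side; `key` and `doCompile` are placeholders, and the exact location of this lookup is not visible in these hunks, though the new `import is.hail.expr.ir.compile.Compile` lines in the backend files below suggest Compile is the likely home:

    // Placeholder names (`key`, `doCompile`); illustrative only.
    val fn: CompiledFunction[_] = ctx.CodeCache.getOrElseUpdate(key, doCompile(ctx))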

hail/hail/src/is/hail/backend/local/LocalBackend.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions
88
import is.hail.expr.Validate
99
import is.hail.expr.ir._
1010
import is.hail.expr.ir.analyses.SemanticHash
11+
import is.hail.expr.ir.compile.Compile
1112
import is.hail.expr.ir.defs.MakeTuple
1213
import is.hail.expr.ir.lowering._
1314
import is.hail.io.fs._
@@ -84,13 +85,14 @@ object LocalBackend {
8485
class LocalBackend(
8586
val tmpdir: String,
8687
override val references: mutable.Map[String, ReferenceGenome],
87-
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {
88+
) extends Backend with Py4JBackendExtensions {
8889

8990
override def backend: Backend = this
9091
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
9192
override def longLifeTempFileManager: TempFileManager = null
9293

93-
private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader())
94+
private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
95+
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
9496

9597
// flags can be set after construction from python
9698
def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
@@ -114,6 +116,7 @@ class LocalBackend(
114116
},
115117
new IrMetadata(),
116118
ImmutableMap.empty,
119+
codeCache,
117120
)(f)
118121
}
119122
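LocalBackend now owns a backend-lifetime Cache of 50 compiled functions and hands it to every ExecuteContext it creates, so compiled code is still reused across queries. Cache here is Hail's own bounded map (it is passed where the context expects a mutable.Map); a rough stand-in with the same shape, under the assumption of LRU eviction at capacity, could look like:

    import java.util

    // Rough stand-in for Hail's Cache[K, V] (assumption: LRU eviction once
    // `capacity` entries are held, which is what a fixed-size compiled-code
    // cache wants). The real class also conforms to scala mutable.Map.
    class BoundedCache[K, V](capacity: Int) {
      private[this] val m =
        new util.LinkedHashMap[K, V](capacity, 0.75f, /* accessOrder = */ true) {
          override def removeEldestEntry(eldest: util.Map.Entry[K, V]): Boolean =
            size() > capacity
        }
      def get(k: K): Option[V] = Option(m.get(k))
      def put(k: K, v: V): Unit = m.put(k, v)
    }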

hail/hail/src/is/hail/backend/service/ServiceBackend.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ import is.hail.asm4s._
66
import is.hail.backend._
77
import is.hail.expr.Validate
88
import is.hail.expr.ir.{
9-
Compile, IR, IRParser, IRSize, LoweringAnalyses, SortField, TableIR, TableReader, TypeCheck,
9+
IR, IRParser, IRSize, LoweringAnalyses, SortField, TableIR, TableReader, TypeCheck,
1010
}
1111
import is.hail.expr.ir.analyses.SemanticHash
12+
import is.hail.expr.ir.compile.Compile
1213
import is.hail.expr.ir.defs.MakeTuple
1314
import is.hail.expr.ir.functions.IRFunctionRegistry
1415
import is.hail.expr.ir.lowering._
@@ -51,7 +52,6 @@ class ServiceBackendContext(
5152
) extends BackendContext with Serializable {}
5253

5354
object ServiceBackend {
54-
private val log = Logger.getLogger(getClass.getName())
5555

5656
def apply(
5757
jarLocation: String,
@@ -130,8 +130,7 @@ class ServiceBackend(
130130
val fs: FS,
131131
val serviceBackendContext: ServiceBackendContext,
132132
val scratchDir: String,
133-
) extends Backend with BackendWithNoCodeCache {
134-
import ServiceBackend.log
133+
) extends Backend with Logging {
135134

136135
private[this] var stageCount = 0
137136
private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000
@@ -393,6 +392,7 @@ class ServiceBackend(
393392
serviceBackendContext,
394393
new IrMetadata(),
395394
ImmutableMap.empty,
395+
mutable.Map.empty,
396396
)(f)
397397
}
398398
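ServiceBackend, previously the BackendWithNoCodeCache user, now passes a fresh mutable.Map.empty into each ExecuteContext, so a compiled function is only reusable within a single query and never across them, reproducing the old pass-through behaviour. A toy, standard-library-only demonstration of why a backend-owned map and a per-context map behave differently (names are illustrative, not from the commit):

    import scala.collection.mutable

    object CodeCacheDemo extends App {
      var compilations = 0
      def compile(): String = { compilations += 1; "<compiled>" }

      // Backend-owned map (LocalBackend/SparkBackend style): work is reused across queries.
      val shared = mutable.Map.empty[String, String]
      (1 to 3).foreach(_ => shared.getOrElseUpdate("key", compile()))
      assert(compilations == 1) // compiled once, then two cache hits

      // Fresh map per query (ServiceBackend style): every lookup misses.
      (1 to 3).foreach(_ => mutable.Map.empty[String, String].getOrElseUpdate("key", compile()))
      assert(compilations == 4) // recompiled each time
    }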

hail/hail/src/is/hail/backend/spark/SparkBackend.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@ import is.hail.{HailContext, HailFeatureFlags}
44
import is.hail.annotations._
55
import is.hail.asm4s._
66
import is.hail.backend._
7-
import is.hail.backend.caching.BlockMatrixCache
87
import is.hail.backend.py4j.Py4JBackendExtensions
98
import is.hail.expr.Validate
109
import is.hail.expr.ir._
1110
import is.hail.expr.ir.analyses.SemanticHash
11+
import is.hail.expr.ir.compile.Compile
1212
import is.hail.expr.ir.defs.MakeTuple
1313
import is.hail.expr.ir.lowering._
1414
import is.hail.io.{BufferSpec, TypedCodecSpec}
1515
import is.hail.io.fs._
16+
import is.hail.linalg.BlockMatrix
1617
import is.hail.rvd.RVD
1718
import is.hail.types._
1819
import is.hail.types.physical.{PStruct, PTuple}
@@ -26,9 +27,10 @@ import scala.collection.mutable.ArrayBuffer
2627
import scala.concurrent.ExecutionException
2728
import scala.reflect.ClassTag
2829
import scala.util.control.NonFatal
30+
2931
import java.io.PrintWriter
32+
3033
import com.fasterxml.jackson.core.StreamReadConstraints
31-
import is.hail.linalg.BlockMatrix
3234
import org.apache.hadoop
3335
import org.apache.hadoop.conf.Configuration
3436
import org.apache.spark._
@@ -320,7 +322,7 @@ class SparkBackend(
320322
override val references: mutable.Map[String, ReferenceGenome],
321323
gcsRequesterPaysProject: String,
322324
gcsRequesterPaysBuckets: String,
323-
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {
325+
) extends Backend with Py4JBackendExtensions {
324326

325327
assert(gcsRequesterPaysProject != null || gcsRequesterPaysBuckets == null)
326328
lazy val sparkSession: SparkSession = SparkSession.builder().config(sc.getConf).getOrCreate()
@@ -352,6 +354,7 @@ class SparkBackend(
352354
new OwningTempFileManager(fs)
353355

354356
private[this] val bmCache = mutable.Map.empty[String, BlockMatrix]
357+
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
355358

356359
def createExecuteContextForTests(
357360
timer: ExecutionTimer,
@@ -375,6 +378,7 @@ class SparkBackend(
375378
},
376379
new IrMetadata(),
377380
ImmutableMap.empty,
381+
mutable.Map.empty,
378382
)
379383

380384
override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
@@ -395,6 +399,7 @@ class SparkBackend(
395399
},
396400
new IrMetadata(),
397401
bmCache,
402+
codeCache,
398403
)(f)
399404
}
400405
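SparkBackend mirrors LocalBackend: a shared, bounded codeCache feeds the production withExecuteContext path, while createExecuteContextForTests receives a throwaway mutable.Map.empty so tests do not share compiled code. A condensed, illustrative summary of which cache each driver now hands to ExecuteContext, as read from the hunks above (a data-structure sketch, not code from the commit):

    // Illustrative only; drawn from the diffs in this commit.
    val codeCachePerBackend: Map[String, String] = Map(
      "LocalBackend"                 -> "new Cache[CodeCacheKey, CompiledFunction[_]](50)",
      "SparkBackend (production)"    -> "new Cache[CodeCacheKey, CompiledFunction[_]](50)",
      "SparkBackend (test contexts)" -> "mutable.Map.empty",
      "ServiceBackend"               -> "mutable.Map.empty",
    )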
