Skip to content

Commit 95646af

Browse files
committed
[query] Remove persistedIr from Backend interface
1 parent 8fd036a commit 95646af

File tree

12 files changed

+466
-500
lines changed

12 files changed

+466
-500
lines changed

hail/src/main/scala/is/hail/backend/Backend.scala

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@ package is.hail.backend
33
import is.hail.asm4s._
44
import is.hail.backend.Backend.jsonToBytes
55
import is.hail.backend.spark.SparkBackend
6-
import is.hail.expr.ir.{
7-
BaseIR, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader,
8-
}
6+
import is.hail.expr.ir.{IR, IRParser, LoweringAnalyses, SortField, TableIR, TableReader}
97
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
108
import is.hail.io.{BufferSpec, TypedCodecSpec}
119
import is.hail.io.fs._
@@ -18,7 +16,6 @@ import is.hail.types.virtual.TFloat64
1816
import is.hail.utils._
1917
import is.hail.variant.ReferenceGenome
2018

21-
import scala.collection.mutable
2219
import scala.reflect.ClassTag
2320

2421
import java.io._
@@ -80,8 +77,6 @@ abstract class Backend extends Closeable {
8077
StreamReadConstraints.builder().maxStringLength(Integer.MAX_VALUE).build()
8178
)
8279

83-
val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()
84-
8580
def defaultParallelism: Int
8681

8782
def canExecuteParallelTasksOnDriver: Boolean = true
@@ -140,30 +135,30 @@ abstract class Backend extends Closeable {
140135
def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T
141136

142137
final def valueType(s: String): Array[Byte] =
143-
jsonToBytes {
144-
withExecuteContext { ctx =>
145-
IRParser.parse_value_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
138+
withExecuteContext { ctx =>
139+
jsonToBytes {
140+
IRParser.parse_value_ir(ctx, s).typ.toJSON
146141
}
147142
}
148143

149144
final def tableType(s: String): Array[Byte] =
150-
jsonToBytes {
151-
withExecuteContext { ctx =>
152-
IRParser.parse_table_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
145+
withExecuteContext { ctx =>
146+
jsonToBytes {
147+
IRParser.parse_table_ir(ctx, s).typ.toJSON
153148
}
154149
}
155150

156151
final def matrixTableType(s: String): Array[Byte] =
157-
jsonToBytes {
158-
withExecuteContext { ctx =>
159-
IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
152+
withExecuteContext { ctx =>
153+
jsonToBytes {
154+
IRParser.parse_matrix_ir(ctx, s).typ.toJSON
160155
}
161156
}
162157

163158
final def blockMatrixType(s: String): Array[Byte] =
164-
jsonToBytes {
165-
withExecuteContext { ctx =>
166-
IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
159+
withExecuteContext { ctx =>
160+
jsonToBytes {
161+
IRParser.parse_blockmatrix_ir(ctx, s).typ.toJSON
167162
}
168163
}
169164

hail/src/main/scala/is/hail/backend/BackendServer.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package is.hail.backend
22

3-
import is.hail.expr.ir.{IRParser, IRParserEnvironment}
3+
import is.hail.expr.ir.IRParser
44
import is.hail.utils._
55

66
import scala.util.control.NonFatal
@@ -89,10 +89,7 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler {
8989
backend.withExecuteContext { ctx =>
9090
val (res, timings) = ExecutionTimer.time { timer =>
9191
ctx.local(timer = timer) { ctx =>
92-
val irData = IRParser.parse_value_ir(
93-
irStr,
94-
IRParserEnvironment(ctx, irMap = backend.persistedIR.toMap),
95-
)
92+
val irData = IRParser.parse_value_ir(ctx, irStr)
9693
backend.execute(ctx, irData)
9794
}
9895
}

hail/src/main/scala/is/hail/backend/ExecuteContext.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags}
44
import is.hail.annotations.{Region, RegionPool}
55
import is.hail.asm4s.HailClassLoader
66
import is.hail.backend.local.LocalTaskContext
7-
import is.hail.expr.ir.{CodeCacheKey, CompiledFunction}
7+
import is.hail.expr.ir.{BaseIR, CodeCacheKey, CompiledFunction}
88
import is.hail.expr.ir.lowering.IrMetadata
99
import is.hail.io.fs.FS
1010
import is.hail.linalg.BlockMatrix
@@ -75,6 +75,7 @@ object ExecuteContext {
7575
irMetadata: IrMetadata,
7676
blockMatrixCache: mutable.Map[String, BlockMatrix],
7777
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
78+
irCache: mutable.Map[Int, BaseIR],
7879
)(
7980
f: ExecuteContext => T
8081
): T = {
@@ -95,6 +96,7 @@ object ExecuteContext {
9596
irMetadata,
9697
blockMatrixCache,
9798
codeCache,
99+
irCache,
98100
))(f(_))
99101
}
100102
}
@@ -126,6 +128,7 @@ class ExecuteContext(
126128
val irMetadata: IrMetadata,
127129
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
128130
val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
131+
val IrCache: mutable.Map[Int, BaseIR],
129132
) extends Closeable {
130133

131134
val rngNonce: Long =
@@ -196,6 +199,7 @@ class ExecuteContext(
196199
irMetadata: IrMetadata = this.irMetadata,
197200
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
198201
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
202+
irCache: mutable.Map[Int, BaseIR] = this.IrCache,
199203
)(
200204
f: ExecuteContext => A
201205
): A =
@@ -214,5 +218,6 @@ class ExecuteContext(
214218
irMetadata,
215219
blockMatrixCache,
216220
codeCache,
221+
irCache,
217222
))(f)
218223
}

hail/src/main/scala/is/hail/backend/local/LocalBackend.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class LocalBackend(
7979

8080
private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
8181
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
82+
private[this] val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()
8283

8384
// flags can be set after construction from python
8485
def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
@@ -103,6 +104,7 @@ class LocalBackend(
103104
new IrMetadata(),
104105
ImmutableMap.empty,
105106
codeCache,
107+
persistedIR,
106108
)(f)
107109
}
108110

hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ import is.hail.HailFeatureFlags
44
import is.hail.backend.{Backend, ExecuteContext, NonOwningTempFileManager, TempFileManager}
55
import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex}
66
import is.hail.expr.ir.{
7-
BaseIR, BindingEnv, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IR, IRParser,
8-
IRParserEnvironment, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, Name,
9-
NativeReaderOptions, TableIR, TableLiteral, TableValue,
7+
BaseIR, BindingEnv, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IR, IRParser, Interpret,
8+
MatrixIR, MatrixNativeReader, MatrixRead, Name, NativeReaderOptions, TableIR, TableLiteral,
9+
TableValue,
1010
}
1111
import is.hail.expr.ir.IRParser.parseType
1212
import is.hail.expr.ir.functions.IRFunctionRegistry
@@ -34,7 +34,6 @@ import sourcecode.Enclosing
3434
trait Py4JBackendExtensions {
3535
def backend: Backend
3636
def references: mutable.Map[String, ReferenceGenome]
37-
def persistedIR: mutable.Map[Int, BaseIR]
3837
def flags: HailFeatureFlags
3938
def longLifeTempFileManager: TempFileManager
4039

@@ -54,14 +53,14 @@ trait Py4JBackendExtensions {
5453
irID
5554
}
5655

57-
private[this] def addJavaIR(ir: BaseIR): Int = {
56+
private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = {
5857
val id = nextIRID()
59-
persistedIR += (id -> ir)
58+
ctx.IrCache += (id -> ir)
6059
id
6160
}
6261

6362
def pyRemoveJavaIR(id: Int): Unit =
64-
persistedIR.remove(id)
63+
backend.withExecuteContext(_.IrCache.remove(id))
6564

6665
def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
6766
backend.withExecuteContext { ctx =>
@@ -118,7 +117,7 @@ trait Py4JBackendExtensions {
118117
argTypeStrs: java.util.ArrayList[String],
119118
returnType: String,
120119
bodyStr: String,
121-
): Unit = {
120+
): Unit =
122121
backend.withExecuteContext { ctx =>
123122
IRFunctionRegistry.registerIR(
124123
ctx,
@@ -130,17 +129,16 @@ trait Py4JBackendExtensions {
130129
bodyStr,
131130
)
132131
}
133-
}
134132

135133
def pyExecuteLiteral(irStr: String): Int =
136134
backend.withExecuteContext { ctx =>
137-
val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap))
135+
val ir = IRParser.parse_value_ir(ctx, irStr)
138136
assert(ir.typ.isRealizable)
139137
backend.execute(ctx, ir) match {
140138
case Left(_) => throw new HailException("Can't create literal")
141139
case Right((pt, addr)) =>
142140
val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0)
143-
addJavaIR(field)
141+
addJavaIR(ctx, field)
144142
}
145143
}
146144

@@ -159,14 +157,14 @@ trait Py4JBackendExtensions {
159157
),
160158
ctx.theHailClassLoader,
161159
)
162-
val id = addJavaIR(tir)
160+
val id = addJavaIR(ctx, tir)
163161
(id, JsonMethods.compact(tir.typ.toJSON))
164162
}
165163
}
166164

167165
def pyToDF(s: String): DataFrame =
168166
backend.withExecuteContext { ctx =>
169-
val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
167+
val tir = IRParser.parse_table_ir(ctx, s)
170168
Interpret(tir, ctx).toDF()
171169
}
172170

@@ -231,27 +229,23 @@ trait Py4JBackendExtensions {
231229
def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR =
232230
backend.withExecuteContext { ctx =>
233231
IRParser.parse_value_ir(
232+
ctx,
234233
s,
235-
IRParserEnvironment(ctx, irMap = persistedIR.toMap),
236234
BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) =>
237235
Name(n) -> IRParser.parseType(t)
238236
}.toSeq: _*),
239237
)
240238
}
241239

242240
def parse_table_ir(s: String): TableIR =
243-
withExecuteContext(selfContainedExecution = false) { ctx =>
244-
IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
245-
}
241+
withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_table_ir(ctx, s))
246242

247243
def parse_matrix_ir(s: String): MatrixIR =
248-
withExecuteContext(selfContainedExecution = false) { ctx =>
249-
IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
250-
}
244+
withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_matrix_ir(ctx, s))
251245

252246
def parse_blockmatrix_ir(s: String): BlockMatrixIR =
253247
withExecuteContext(selfContainedExecution = false) { ctx =>
254-
IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
248+
IRParser.parse_blockmatrix_ir(ctx, s)
255249
}
256250

257251
def loadReferencesFromDataset(path: String): Array[Byte] =

hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ class ServiceBackend(
387387
new IrMetadata(),
388388
ImmutableMap.empty,
389389
mutable.Map.empty,
390+
ImmutableMap.empty,
390391
)(f)
391392
}
392393

hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ class SparkBackend(
341341

342342
private[this] val bmCache = new BlockMatrixCache()
343343
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
344+
private[this] val persistedIr = mutable.Map.empty[Int, BaseIR]
344345

345346
def createExecuteContextForTests(
346347
timer: ExecutionTimer,
@@ -365,6 +366,7 @@ class SparkBackend(
365366
new IrMetadata(),
366367
ImmutableMap.empty,
367368
mutable.Map.empty,
369+
ImmutableMap.empty,
368370
)
369371

370372
override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
@@ -386,6 +388,7 @@ class SparkBackend(
386388
new IrMetadata(),
387389
bmCache,
388390
codeCache,
391+
persistedIr,
389392
)(f)
390393
}
391394

hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,14 @@ case class MatrixLiteral(typ: MatrixType, tl: TableLiteral) extends MatrixIR {
113113
}
114114

115115
object MatrixReader {
116-
def fromJson(env: IRParserEnvironment, jv: JValue): MatrixReader = {
116+
def fromJson(ctx: ExecuteContext, jv: JValue): MatrixReader = {
117117
implicit val formats: Formats = DefaultFormats
118118
(jv \ "name").extract[String] match {
119-
case "MatrixRangeReader" => MatrixRangeReader.fromJValue(env.ctx, jv)
120-
case "MatrixNativeReader" => MatrixNativeReader.fromJValue(env.ctx.fs, jv)
121-
case "MatrixBGENReader" => MatrixBGENReader.fromJValue(env, jv)
122-
case "MatrixPLINKReader" => MatrixPLINKReader.fromJValue(env.ctx, jv)
123-
case "MatrixVCFReader" => MatrixVCFReader.fromJValue(env.ctx, jv)
119+
case "MatrixRangeReader" => MatrixRangeReader.fromJValue(ctx, jv)
120+
case "MatrixNativeReader" => MatrixNativeReader.fromJValue(ctx.fs, jv)
121+
case "MatrixBGENReader" => MatrixBGENReader.fromJValue(ctx, jv)
122+
case "MatrixPLINKReader" => MatrixPLINKReader.fromJValue(ctx, jv)
123+
case "MatrixVCFReader" => MatrixVCFReader.fromJValue(ctx, jv)
124124
}
125125
}
126126

0 commit comments

Comments
 (0)