@@ -4,8 +4,8 @@ import is.hail.HailFeatureFlags
44import is .hail .backend .{Backend , ExecuteContext , NonOwningTempFileManager , TempFileManager }
55import is .hail .expr .{JSONAnnotationImpex , SparkAnnotationImpex }
66import is .hail .expr .ir .{
7- BaseIR , BindingEnv , BlockMatrixIR , IR , IRParser , IRParserEnvironment , Interpret , MatrixIR ,
8- MatrixNativeReader , MatrixRead , Name , NativeReaderOptions , TableIR , TableLiteral , TableValue ,
7+ BaseIR , BindingEnv , BlockMatrixIR , IR , IRParser , Interpret , MatrixIR , MatrixNativeReader ,
8+ MatrixRead , Name , NativeReaderOptions , TableIR , TableLiteral , TableValue ,
99}
1010import is .hail .expr .ir .IRParser .parseType
1111import is .hail .expr .ir .defs .{EncodedLiteral , GetFieldByIdx }
@@ -34,7 +34,6 @@ import sourcecode.Enclosing
3434trait Py4JBackendExtensions {
3535 def backend : Backend
3636 def references : mutable.Map [String , ReferenceGenome ]
37- def persistedIR : mutable.Map [Int , BaseIR ]
3837 def flags : HailFeatureFlags
3938 def longLifeTempFileManager : TempFileManager
4039
@@ -54,14 +53,14 @@ trait Py4JBackendExtensions {
5453 irID
5554 }
5655
57- private [this ] def addJavaIR (ir : BaseIR ): Int = {
56+ private [this ] def addJavaIR (ctx : ExecuteContext , ir : BaseIR ): Int = {
5857 val id = nextIRID()
59- persistedIR += (id -> ir)
58+ ctx. IrCache += (id -> ir)
6059 id
6160 }
6261
6362 def pyRemoveJavaIR (id : Int ): Unit =
64- persistedIR. remove(id)
63+ backend.withExecuteContext(_. IrCache . remove(id) )
6564
6665 def pyAddSequence (name : String , fastaFile : String , indexFile : String ): Unit =
6766 backend.withExecuteContext { ctx =>
@@ -118,7 +117,7 @@ trait Py4JBackendExtensions {
118117 argTypeStrs : java.util.ArrayList [String ],
119118 returnType : String ,
120119 bodyStr : String ,
121- ): Unit = {
120+ ): Unit =
122121 backend.withExecuteContext { ctx =>
123122 IRFunctionRegistry .registerIR(
124123 ctx,
@@ -130,17 +129,16 @@ trait Py4JBackendExtensions {
130129 bodyStr,
131130 )
132131 }
133- }
134132
135133 def pyExecuteLiteral (irStr : String ): Int =
136134 backend.withExecuteContext { ctx =>
137- val ir = IRParser .parse_value_ir(irStr, IRParserEnvironment ( ctx, persistedIR.toMap) )
135+ val ir = IRParser .parse_value_ir(ctx, irStr )
138136 assert(ir.typ.isRealizable)
139137 backend.execute(ctx, ir) match {
140138 case Left (_) => throw new HailException (" Can't create literal" )
141139 case Right ((pt, addr)) =>
142140 val field = GetFieldByIdx (EncodedLiteral .fromPTypeAndAddress(pt, addr, ctx), 0 )
143- addJavaIR(field)
141+ addJavaIR(ctx, field)
144142 }
145143 }
146144
@@ -159,14 +157,14 @@ trait Py4JBackendExtensions {
159157 ),
160158 ctx.theHailClassLoader,
161159 )
162- val id = addJavaIR(tir)
160+ val id = addJavaIR(ctx, tir)
163161 (id, JsonMethods .compact(tir.typ.toJSON))
164162 }
165163 }
166164
167165 def pyToDF (s : String ): DataFrame =
168166 backend.withExecuteContext { ctx =>
169- val tir = IRParser .parse_table_ir(s, IRParserEnvironment ( ctx, irMap = persistedIR.toMap) )
167+ val tir = IRParser .parse_table_ir(ctx, s )
170168 Interpret (tir, ctx).toDF()
171169 }
172170
@@ -219,27 +217,23 @@ trait Py4JBackendExtensions {
219217 def parse_value_ir (s : String , refMap : java.util.Map [String , String ]): IR =
220218 backend.withExecuteContext { ctx =>
221219 IRParser .parse_value_ir(
220+ ctx,
222221 s,
223- IRParserEnvironment (ctx, irMap = persistedIR.toMap),
224222 BindingEnv .eval(refMap.asScala.toMap.map { case (n, t) =>
225223 Name (n) -> IRParser .parseType(t)
226224 }.toSeq: _* ),
227225 )
228226 }
229227
230228 def parse_table_ir (s : String ): TableIR =
231- withExecuteContext(selfContainedExecution = false ) { ctx =>
232- IRParser .parse_table_ir(s, IRParserEnvironment (ctx, irMap = persistedIR.toMap))
233- }
229+ withExecuteContext(selfContainedExecution = false )(ctx => IRParser .parse_table_ir(ctx, s))
234230
235231 def parse_matrix_ir (s : String ): MatrixIR =
236- withExecuteContext(selfContainedExecution = false ) { ctx =>
237- IRParser .parse_matrix_ir(s, IRParserEnvironment (ctx, irMap = persistedIR.toMap))
238- }
232+ withExecuteContext(selfContainedExecution = false )(ctx => IRParser .parse_matrix_ir(ctx, s))
239233
240234 def parse_blockmatrix_ir (s : String ): BlockMatrixIR =
241235 withExecuteContext(selfContainedExecution = false ) { ctx =>
242- IRParser .parse_blockmatrix_ir(s, IRParserEnvironment ( ctx, irMap = persistedIR.toMap) )
236+ IRParser .parse_blockmatrix_ir(ctx, s )
243237 }
244238
245239 def loadReferencesFromDataset (path : String ): Array [Byte ] =
0 commit comments