Skip to content

Commit 0bdffbc

Browse files
authored
[query] Expose references via ExecuteContext (#14686)
This change is split out from a larger refactoring effort on the various Backend implementations. The goals of this effort are to provide query-level configuration to the backend that's currently tied to the lifetime of a backend, reduce code duplication and reduce state duplication. In this change, I'm restoring references to the execute context [1] and decoupling them from the backend. In a future change, they'll be lifted out of the backend implementations altogether. This is to reduce the surface area of the Backend interface to the details that are actually different. Both the Local and Spark backend have state that's manipulated from Python via various py methods. These pollute the Backend interface [2] and so have been extracted into the trait Py4JBackendExtensions. In future changes, this will become a facade that owns state set in Python. Notes: [1] "Restoring" refers to old behaviour I foolishly removed in fe5ed32. [2] "Pollute" in that they obfuscate what's different about backend query planning and execution.
1 parent 5d9c642 commit 0bdffbc

27 files changed

+296
-333
lines changed

hail/hail/src/is/hail/HailFeatureFlags.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ class HailFeatureFlags private (
6868
flags.update(flag, value)
6969
}
7070

71+
def +(feature: (String, String)): HailFeatureFlags =
72+
new HailFeatureFlags(flags + (feature._1 -> feature._2))
73+
7174
def get(flag: String): String = flags(flag)
7275

7376
def lookup(flag: String): Option[String] =

hail/hail/src/is/hail/backend/Backend.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,6 @@ abstract class Backend extends Closeable {
105105
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
106106
: CompiledFunction[T]
107107

108-
def references: mutable.Map[String, ReferenceGenome]
109-
110108
def lowerDistributedSort(
111109
ctx: ExecuteContext,
112110
stage: TableStage,
@@ -181,10 +179,8 @@ abstract class Backend extends Closeable {
181179
): Array[Byte] =
182180
withExecuteContext { ctx =>
183181
jsonToBytes {
184-
Extraction.decompose {
185-
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
186-
xContigs, yContigs, mtContigs, parInput).toJSON
187-
}(defaultJSONFormats)
182+
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
183+
xContigs, yContigs, mtContigs, parInput).toJSON
188184
}
189185
}
190186

hail/hail/src/is/hail/backend/ExecuteContext.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ object ExecuteContext {
6363
tmpdir: String,
6464
localTmpdir: String,
6565
backend: Backend,
66+
references: Map[String, ReferenceGenome],
6667
fs: FS,
6768
timer: ExecutionTimer,
6869
tempFileManager: TempFileManager,
@@ -79,6 +80,7 @@ object ExecuteContext {
7980
tmpdir,
8081
localTmpdir,
8182
backend,
83+
references,
8284
fs,
8385
region,
8486
timer,
@@ -107,6 +109,7 @@ class ExecuteContext(
107109
val tmpdir: String,
108110
val localTmpdir: String,
109111
val backend: Backend,
112+
val references: Map[String, ReferenceGenome],
110113
val fs: FS,
111114
val r: Region,
112115
val timer: ExecutionTimer,
@@ -128,7 +131,7 @@ class ExecuteContext(
128131
)
129132
}
130133

131-
def stateManager = HailStateManager(backend.references.toMap)
134+
val stateManager = HailStateManager(references)
132135

133136
val tempFileManager: TempFileManager =
134137
if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs)
@@ -154,8 +157,6 @@ class ExecuteContext(
154157

155158
def getFlag(name: String): String = flags.get(name)
156159

157-
def getReference(name: String): ReferenceGenome = backend.references(name)
158-
159160
def shouldWriteIRFiles(): Boolean = getFlag("write_ir_files") != null
160161

161162
def shouldNotLogIR(): Boolean = flags.get("no_ir_logging") != null
@@ -174,6 +175,7 @@ class ExecuteContext(
174175
tmpdir: String = this.tmpdir,
175176
localTmpdir: String = this.localTmpdir,
176177
backend: Backend = this.backend,
178+
references: Map[String, ReferenceGenome] = this.references,
177179
fs: FS = this.fs,
178180
r: Region = this.r,
179181
timer: ExecutionTimer = this.timer,
@@ -189,6 +191,7 @@ class ExecuteContext(
189191
tmpdir,
190192
localTmpdir,
191193
backend,
194+
references,
192195
fs,
193196
r,
194197
timer,

hail/hail/src/is/hail/backend/local/LocalBackend.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class LocalBackend(
8787
override val references: mutable.Map[String, ReferenceGenome],
8888
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {
8989

90+
override def backend: Backend = this
9091
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
9192
override def longLifeTempFileManager: TempFileManager = null
9293

@@ -102,6 +103,7 @@ class LocalBackend(
102103
tmpdir,
103104
tmpdir,
104105
this,
106+
references.toMap,
105107
fs,
106108
timer,
107109
null,

hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@ import is.hail.expr.ir.{
1010
import is.hail.expr.ir.IRParser.parseType
1111
import is.hail.expr.ir.defs.{EncodedLiteral, GetFieldByIdx}
1212
import is.hail.expr.ir.functions.IRFunctionRegistry
13+
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
1314
import is.hail.linalg.RowMatrix
1415
import is.hail.types.physical.PStruct
1516
import is.hail.types.virtual.{TArray, TInterval}
1617
import is.hail.utils.{defaultJSONFormats, log, toRichIterable, FastSeq, HailException, Interval}
1718
import is.hail.variant.ReferenceGenome
1819

20+
import scala.collection.mutable
1921
import scala.jdk.CollectionConverters.{
2022
asScalaBufferConverter, mapAsScalaMapConverter, seqAsJavaListConverter,
2123
}
@@ -29,7 +31,10 @@ import org.json4s.Formats
2931
import org.json4s.jackson.{JsonMethods, Serialization}
3032
import sourcecode.Enclosing
3133

32-
trait Py4JBackendExtensions { this: Backend =>
34+
trait Py4JBackendExtensions {
35+
def backend: Backend
36+
def references: mutable.Map[String, ReferenceGenome]
37+
def persistedIR: mutable.Map[Int, BaseIR]
3338
def flags: HailFeatureFlags
3439
def longLifeTempFileManager: TempFileManager
3540

@@ -59,7 +64,9 @@ trait Py4JBackendExtensions { this: Backend =>
5964
persistedIR.remove(id)
6065

6166
def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
62-
withExecuteContext(ctx => references(name).addSequence(ctx, fastaFile, indexFile))
67+
backend.withExecuteContext { ctx =>
68+
references(name).addSequence(IndexedFastaSequenceFile(ctx.fs, fastaFile, indexFile))
69+
}
6370

6471
def pyRemoveSequence(name: String): Unit =
6572
references(name).removeSequence()
@@ -74,7 +81,7 @@ trait Py4JBackendExtensions { this: Backend =>
7481
partitionSize: java.lang.Integer,
7582
entries: String,
7683
): Unit =
77-
withExecuteContext { ctx =>
84+
backend.withExecuteContext { ctx =>
7885
val rm = RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize)
7986
entries match {
8087
case "full" =>
@@ -112,7 +119,7 @@ trait Py4JBackendExtensions { this: Backend =>
112119
returnType: String,
113120
bodyStr: String,
114121
): Unit = {
115-
withExecuteContext { ctx =>
122+
backend.withExecuteContext { ctx =>
116123
IRFunctionRegistry.registerIR(
117124
ctx,
118125
name,
@@ -126,10 +133,10 @@ trait Py4JBackendExtensions { this: Backend =>
126133
}
127134

128135
def pyExecuteLiteral(irStr: String): Int =
129-
withExecuteContext { ctx =>
136+
backend.withExecuteContext { ctx =>
130137
val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap))
131138
assert(ir.typ.isRealizable)
132-
execute(ctx, ir) match {
139+
backend.execute(ctx, ir) match {
133140
case Left(_) => throw new HailException("Can't create literal")
134141
case Right((pt, addr)) =>
135142
val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0)
@@ -158,13 +165,13 @@ trait Py4JBackendExtensions { this: Backend =>
158165
}
159166

160167
def pyToDF(s: String): DataFrame =
161-
withExecuteContext { ctx =>
168+
backend.withExecuteContext { ctx =>
162169
val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
163170
Interpret(tir, ctx).toDF()
164171
}
165172

166173
def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] =
167-
withExecuteContext { ctx =>
174+
backend.withExecuteContext { ctx =>
168175
log.info("pyReadMultipleMatrixTables: got query")
169176
val kvs = JsonMethods.parse(jsonQuery) match {
170177
case json4s.JObject(values) => values.toMap
@@ -193,19 +200,24 @@ trait Py4JBackendExtensions { this: Backend =>
193200
addReference(ReferenceGenome.fromJSON(jsonConfig))
194201

195202
def pyRemoveReference(name: String): Unit =
196-
references.remove(name)
203+
removeReference(name)
197204

198205
def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit =
199-
withExecuteContext(ctx => references(name).addLiftover(ctx, chainFile, destRGName))
206+
backend.withExecuteContext { ctx =>
207+
references(name).addLiftover(references(destRGName), LiftOver(ctx.fs, chainFile))
208+
}
200209

201210
def pyRemoveLiftover(name: String, destRGName: String): Unit =
202211
references(name).removeLiftover(destRGName)
203212

204213
private[this] def addReference(rg: ReferenceGenome): Unit =
205214
ReferenceGenome.addFatalOnCollision(references, FastSeq(rg))
206215

216+
private[this] def removeReference(name: String): Unit =
217+
references -= name
218+
207219
def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR =
208-
withExecuteContext { ctx =>
220+
backend.withExecuteContext { ctx =>
209221
IRParser.parse_value_ir(
210222
s,
211223
IRParserEnvironment(ctx, irMap = persistedIR.toMap),
@@ -231,7 +243,7 @@ trait Py4JBackendExtensions { this: Backend =>
231243
}
232244

233245
def loadReferencesFromDataset(path: String): Array[Byte] =
234-
withExecuteContext { ctx =>
246+
backend.withExecuteContext { ctx =>
235247
val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path)
236248
ReferenceGenome.addFatalOnCollision(references, rgs)
237249

@@ -245,7 +257,7 @@ trait Py4JBackendExtensions { this: Backend =>
245257
f: ExecuteContext => T
246258
)(implicit E: Enclosing
247259
): T =
248-
withExecuteContext { ctx =>
260+
backend.withExecuteContext { ctx =>
249261
val tempFileManager = longLifeTempFileManager
250262
if (selfContainedExecution && tempFileManager != null) f(ctx)
251263
else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f)

hail/hail/src/is/hail/backend/service/ServiceBackend.scala

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import is.hail.expr.ir.defs.MakeTuple
1313
import is.hail.expr.ir.functions.IRFunctionRegistry
1414
import is.hail.expr.ir.lowering._
1515
import is.hail.io.fs._
16+
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
1617
import is.hail.linalg.BlockMatrix
1718
import is.hail.services.{BatchClient, JobGroupRequest, _}
1819
import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success}
@@ -93,7 +94,16 @@ object ServiceBackend {
9394
rpcConfig.custom_references.map(ReferenceGenome.fromJSON),
9495
)
9596

96-
val backend = new ServiceBackend(
97+
rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) =>
98+
liftoversForSource.foreach { case (destGenome, chainFile) =>
99+
references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile))
100+
}
101+
}
102+
rpcConfig.sequences.foreach { case (rg, seq) =>
103+
references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index))
104+
}
105+
106+
new ServiceBackend(
97107
JarUrl(jarLocation),
98108
name,
99109
theHailClassLoader,
@@ -106,27 +116,14 @@ object ServiceBackend {
106116
backendContext,
107117
scratchDir,
108118
)
109-
110-
backend.withExecuteContext { ctx =>
111-
rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) =>
112-
liftoversForSource.foreach { case (destGenome, chainFile) =>
113-
references(sourceGenome).addLiftover(ctx, chainFile, destGenome)
114-
}
115-
}
116-
rpcConfig.sequences.foreach { case (rg, seq) =>
117-
references(rg).addSequence(ctx, seq.fasta, seq.index)
118-
}
119-
}
120-
121-
backend
122119
}
123120
}
124121

125122
class ServiceBackend(
126123
val jarSpec: JarSpec,
127124
var name: String,
128125
val theHailClassLoader: HailClassLoader,
129-
override val references: mutable.Map[String, ReferenceGenome],
126+
val references: mutable.Map[String, ReferenceGenome],
130127
val batchClient: BatchClient,
131128
val batchConfig: BatchConfig,
132129
val flags: HailFeatureFlags,
@@ -397,6 +394,7 @@ class ServiceBackend(
397394
tmpdir,
398395
"file:///tmp",
399396
this,
397+
references.toMap,
400398
fs,
401399
timer,
402400
null,

hail/hail/src/is/hail/backend/spark/SparkBackend.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ class SparkBackend(
346346
new HadoopFS(new SerializableHadoopConfiguration(conf))
347347
}
348348

349+
override def backend: Backend = this
349350
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
350351

351352
override val longLifeTempFileManager: TempFileManager =
@@ -375,6 +376,7 @@ class SparkBackend(
375376
tmpdir,
376377
localTmpdir,
377378
this,
379+
references.toMap,
378380
fs,
379381
region,
380382
timer,
@@ -394,6 +396,7 @@ class SparkBackend(
394396
tmpdir,
395397
localTmpdir,
396398
this,
399+
references.toMap,
397400
fs,
398401
timer,
399402
null,

hail/hail/src/is/hail/expr/ir/EmitClassBuilder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class EmitModuleBuilder(val ctx: ExecuteContext, val modb: ModuleBuilder) {
7272
}
7373

7474
def referenceGenomes(): IndexedSeq[ReferenceGenome] =
75-
rgContainers.keys.map(ctx.getReference(_)).toIndexedSeq.sortBy(_.name)
75+
rgContainers.keys.map(ctx.references(_)).toIndexedSeq.sortBy(_.name)
7676

7777
def referenceGenomeFields(): IndexedSeq[StaticField[ReferenceGenome]] =
7878
rgContainers.toFastSeq.sortBy(_._1).map(_._2)

hail/hail/src/is/hail/expr/ir/ExtractIntervalFilters.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) {
647647
BoolValue.fromComparison(l, op).restrict(keySet)
648648
case Contig(rgStr) =>
649649
// locus contig equality comparison
650-
val b = getIntervalFromContig(l.asInstanceOf[String], ctx.getReference(rgStr)) match {
650+
val b = getIntervalFromContig(l.asInstanceOf[String], ctx.references(rgStr)) match {
651651
case Some(i) =>
652652
val b = BoolValue(
653653
KeySet(i),
@@ -671,7 +671,7 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) {
671671
case Position(rgStr) =>
672672
// locus position comparison
673673
val posBoolValue = BoolValue.fromComparison(l, op)
674-
val rg = ctx.getReference(rgStr)
674+
val rg = ctx.references(rgStr)
675675
val b = BoolValue(
676676
KeySet(liftPosIntervalsToLocus(posBoolValue.trueBound, rg, ctx)),
677677
KeySet(liftPosIntervalsToLocus(posBoolValue.falseBound, rg, ctx)),

hail/hail/src/is/hail/expr/ir/MatrixValue.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ case class MatrixValue(
210210
ReferenceGenome.exportReferences(
211211
fs,
212212
refPath,
213-
ReferenceGenome.getReferences(t).map(ctx.getReference(_)),
213+
ReferenceGenome.getReferences(t).map(ctx.references(_)),
214214
)
215215
}
216216

0 commit comments

Comments (0)