Skip to content

Commit 901fafe

Browse files
committed
[query] Expose references via ExecuteContext
1 parent 61dfa8c commit 901fafe

27 files changed

+345
-380
lines changed

hail/src/main/scala/is/hail/HailFeatureFlags.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ class HailFeatureFlags private (
6868
flags.update(flag, value)
6969
}
7070

71+
def +(feature: (String, String)): HailFeatureFlags =
72+
new HailFeatureFlags(flags + (feature._1 -> feature._2))
73+
7174
def get(flag: String): String = flags(flag)
7275

7376
def lookup(flag: String): Option[String] =

hail/src/main/scala/is/hail/backend/Backend.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,6 @@ abstract class Backend extends Closeable {
105105
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
106106
: CompiledFunction[T]
107107

108-
def references: mutable.Map[String, ReferenceGenome]
109-
110108
def lowerDistributedSort(
111109
ctx: ExecuteContext,
112110
stage: TableStage,
@@ -181,10 +179,8 @@ abstract class Backend extends Closeable {
181179
): Array[Byte] =
182180
withExecuteContext { ctx =>
183181
jsonToBytes {
184-
Extraction.decompose {
185-
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
186-
xContigs, yContigs, mtContigs, parInput).toJSON
187-
}(defaultJSONFormats)
182+
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
183+
xContigs, yContigs, mtContigs, parInput).toJSON
188184
}
189185
}
190186

hail/src/main/scala/is/hail/backend/ExecuteContext.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ object ExecuteContext {
6363
tmpdir: String,
6464
localTmpdir: String,
6565
backend: Backend,
66+
references: Map[String, ReferenceGenome],
6667
fs: FS,
6768
timer: ExecutionTimer,
6869
tempFileManager: TempFileManager,
@@ -79,6 +80,7 @@ object ExecuteContext {
7980
tmpdir,
8081
localTmpdir,
8182
backend,
83+
references,
8284
fs,
8385
region,
8486
timer,
@@ -107,6 +109,7 @@ class ExecuteContext(
107109
val tmpdir: String,
108110
val localTmpdir: String,
109111
val backend: Backend,
112+
val references: Map[String, ReferenceGenome],
110113
val fs: FS,
111114
val r: Region,
112115
val timer: ExecutionTimer,
@@ -128,7 +131,7 @@ class ExecuteContext(
128131
)
129132
}
130133

131-
def stateManager = HailStateManager(backend.references.toMap)
134+
val stateManager = HailStateManager(references)
132135

133136
val tempFileManager: TempFileManager =
134137
if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs)
@@ -154,8 +157,6 @@ class ExecuteContext(
154157

155158
def getFlag(name: String): String = flags.get(name)
156159

157-
def getReference(name: String): ReferenceGenome = backend.references(name)
158-
159160
def shouldWriteIRFiles(): Boolean = getFlag("write_ir_files") != null
160161

161162
def shouldNotLogIR(): Boolean = flags.get("no_ir_logging") != null
@@ -174,6 +175,7 @@ class ExecuteContext(
174175
tmpdir: String = this.tmpdir,
175176
localTmpdir: String = this.localTmpdir,
176177
backend: Backend = this.backend,
178+
references: Map[String, ReferenceGenome] = this.references,
177179
fs: FS = this.fs,
178180
r: Region = this.r,
179181
timer: ExecutionTimer = this.timer,
@@ -189,6 +191,7 @@ class ExecuteContext(
189191
tmpdir,
190192
localTmpdir,
191193
backend,
194+
references,
192195
fs,
193196
r,
194197
timer,

hail/src/main/scala/is/hail/backend/local/LocalBackend.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class LocalBackend(
8686
override val references: mutable.Map[String, ReferenceGenome],
8787
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {
8888

89+
override def backend: Backend = this
8990
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
9091
override def longLifeTempFileManager: TempFileManager = null
9192

@@ -101,6 +102,7 @@ class LocalBackend(
101102
tmpdir,
102103
tmpdir,
103104
this,
105+
references.toMap,
104106
fs,
105107
timer,
106108
null,

hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@ import is.hail.expr.ir.{
1010
}
1111
import is.hail.expr.ir.IRParser.parseType
1212
import is.hail.expr.ir.functions.IRFunctionRegistry
13+
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
1314
import is.hail.linalg.RowMatrix
1415
import is.hail.types.physical.PStruct
1516
import is.hail.types.virtual.{TArray, TInterval}
1617
import is.hail.utils.{defaultJSONFormats, log, toRichIterable, FastSeq, HailException, Interval}
1718
import is.hail.variant.ReferenceGenome
1819

20+
import scala.collection.mutable
1921
import scala.jdk.CollectionConverters.{
2022
asScalaBufferConverter, mapAsScalaMapConverter, seqAsJavaListConverter,
2123
}
@@ -29,7 +31,10 @@ import org.json4s.Formats
2931
import org.json4s.jackson.{JsonMethods, Serialization}
3032
import sourcecode.Enclosing
3133

32-
trait Py4JBackendExtensions { this: Backend =>
34+
trait Py4JBackendExtensions {
35+
def backend: Backend
36+
def references: mutable.Map[String, ReferenceGenome]
37+
def persistedIR: mutable.Map[Int, BaseIR]
3338
def flags: HailFeatureFlags
3439
def longLifeTempFileManager: TempFileManager
3540

@@ -59,7 +64,9 @@ trait Py4JBackendExtensions { this: Backend =>
5964
persistedIR.remove(id)
6065

6166
def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
62-
withExecuteContext(ctx => references(name).addSequence(ctx, fastaFile, indexFile))
67+
backend.withExecuteContext { ctx =>
68+
references(name).addSequence(IndexedFastaSequenceFile(ctx.fs, fastaFile, indexFile))
69+
}
6370

6471
def pyRemoveSequence(name: String): Unit =
6572
references(name).removeSequence()
@@ -74,7 +81,7 @@ trait Py4JBackendExtensions { this: Backend =>
7481
partitionSize: java.lang.Integer,
7582
entries: String,
7683
): Unit =
77-
withExecuteContext { ctx =>
84+
backend.withExecuteContext { ctx =>
7885
val rm = RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize)
7986
entries match {
8087
case "full" =>
@@ -112,7 +119,7 @@ trait Py4JBackendExtensions { this: Backend =>
112119
returnType: String,
113120
bodyStr: String,
114121
): Unit = {
115-
withExecuteContext { ctx =>
122+
backend.withExecuteContext { ctx =>
116123
IRFunctionRegistry.registerIR(
117124
ctx,
118125
name,
@@ -126,10 +133,10 @@ trait Py4JBackendExtensions { this: Backend =>
126133
}
127134

128135
def pyExecuteLiteral(irStr: String): Int =
129-
withExecuteContext { ctx =>
136+
backend.withExecuteContext { ctx =>
130137
val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap))
131138
assert(ir.typ.isRealizable)
132-
execute(ctx, ir) match {
139+
backend.execute(ctx, ir) match {
133140
case Left(_) => throw new HailException("Can't create literal")
134141
case Right((pt, addr)) =>
135142
val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0)
@@ -158,13 +165,13 @@ trait Py4JBackendExtensions { this: Backend =>
158165
}
159166

160167
def pyToDF(s: String): DataFrame =
161-
withExecuteContext { ctx =>
168+
backend.withExecuteContext { ctx =>
162169
val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
163170
Interpret(tir, ctx).toDF()
164171
}
165172

166173
def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] =
167-
withExecuteContext { ctx =>
174+
backend.withExecuteContext { ctx =>
168175
log.info("pyReadMultipleMatrixTables: got query")
169176
val kvs = JsonMethods.parse(jsonQuery) match {
170177
case json4s.JObject(values) => values.toMap
@@ -193,19 +200,24 @@ trait Py4JBackendExtensions { this: Backend =>
193200
addReference(ReferenceGenome.fromJSON(jsonConfig))
194201

195202
def pyRemoveReference(name: String): Unit =
196-
references.remove(name)
203+
removeReference(name)
197204

198205
def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit =
199-
withExecuteContext(ctx => references(name).addLiftover(ctx, chainFile, destRGName))
206+
backend.withExecuteContext { ctx =>
207+
references(name).addLiftover(references(destRGName), LiftOver(ctx.fs, chainFile))
208+
}
200209

201210
def pyRemoveLiftover(name: String, destRGName: String): Unit =
202211
references(name).removeLiftover(destRGName)
203212

204213
private[this] def addReference(rg: ReferenceGenome): Unit =
205214
ReferenceGenome.addFatalOnCollision(references, FastSeq(rg))
206215

216+
private[this] def removeReference(name: String): Unit =
217+
references -= name
218+
207219
def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR =
208-
withExecuteContext { ctx =>
220+
backend.withExecuteContext { ctx =>
209221
IRParser.parse_value_ir(
210222
s,
211223
IRParserEnvironment(ctx, irMap = persistedIR.toMap),
@@ -231,7 +243,7 @@ trait Py4JBackendExtensions { this: Backend =>
231243
}
232244

233245
def loadReferencesFromDataset(path: String): Array[Byte] =
234-
withExecuteContext { ctx =>
246+
backend.withExecuteContext { ctx =>
235247
val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path)
236248
ReferenceGenome.addFatalOnCollision(references, rgs)
237249

@@ -245,7 +257,7 @@ trait Py4JBackendExtensions { this: Backend =>
245257
f: ExecuteContext => T
246258
)(implicit E: Enclosing
247259
): T =
248-
withExecuteContext { ctx =>
260+
backend.withExecuteContext { ctx =>
249261
val tempFileManager = longLifeTempFileManager
250262
if (selfContainedExecution && tempFileManager != null) f(ctx)
251263
else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f)

hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import is.hail.expr.ir.analyses.SemanticHash
1313
import is.hail.expr.ir.functions.IRFunctionRegistry
1414
import is.hail.expr.ir.lowering._
1515
import is.hail.io.fs._
16+
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
1617
import is.hail.linalg.BlockMatrix
1718
import is.hail.services.{BatchClient, JobGroupRequest, _}
1819
import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success}
@@ -93,7 +94,16 @@ object ServiceBackend {
9394
rpcConfig.custom_references.map(ReferenceGenome.fromJSON),
9495
)
9596

96-
val backend = new ServiceBackend(
97+
rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) =>
98+
liftoversForSource.foreach { case (destGenome, chainFile) =>
99+
references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile))
100+
}
101+
}
102+
rpcConfig.sequences.foreach { case (rg, seq) =>
103+
references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index))
104+
}
105+
106+
new ServiceBackend(
97107
JarUrl(jarLocation),
98108
name,
99109
theHailClassLoader,
@@ -106,27 +116,14 @@ object ServiceBackend {
106116
backendContext,
107117
scratchDir,
108118
)
109-
110-
backend.withExecuteContext { ctx =>
111-
rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) =>
112-
liftoversForSource.foreach { case (destGenome, chainFile) =>
113-
references(sourceGenome).addLiftover(ctx, chainFile, destGenome)
114-
}
115-
}
116-
rpcConfig.sequences.foreach { case (rg, seq) =>
117-
references(rg).addSequence(ctx, seq.fasta, seq.index)
118-
}
119-
}
120-
121-
backend
122119
}
123120
}
124121

125122
class ServiceBackend(
126123
val jarSpec: JarSpec,
127124
var name: String,
128125
val theHailClassLoader: HailClassLoader,
129-
override val references: mutable.Map[String, ReferenceGenome],
126+
val references: mutable.Map[String, ReferenceGenome],
130127
val batchClient: BatchClient,
131128
val batchConfig: BatchConfig,
132129
val flags: HailFeatureFlags,
@@ -397,6 +394,7 @@ class ServiceBackend(
397394
tmpdir,
398395
"file:///tmp",
399396
this,
397+
references.toMap,
400398
fs,
401399
timer,
402400
null,

hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ class SparkBackend(
345345
new HadoopFS(new SerializableHadoopConfiguration(conf))
346346
}
347347

348+
override def backend: Backend = this
348349
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
349350

350351
override val longLifeTempFileManager: TempFileManager =
@@ -374,6 +375,7 @@ class SparkBackend(
374375
tmpdir,
375376
localTmpdir,
376377
this,
378+
references.toMap,
377379
fs,
378380
region,
379381
timer,
@@ -393,6 +395,7 @@ class SparkBackend(
393395
tmpdir,
394396
localTmpdir,
395397
this,
398+
references.toMap,
396399
fs,
397400
timer,
398401
null,

hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class EmitModuleBuilder(val ctx: ExecuteContext, val modb: ModuleBuilder) {
7171
}
7272

7373
def referenceGenomes(): IndexedSeq[ReferenceGenome] =
74-
rgContainers.keys.map(ctx.getReference(_)).toIndexedSeq.sortBy(_.name)
74+
rgContainers.keys.map(ctx.references(_)).toIndexedSeq.sortBy(_.name)
7575

7676
def referenceGenomeFields(): IndexedSeq[StaticField[ReferenceGenome]] =
7777
rgContainers.toFastSeq.sortBy(_._1).map(_._2)

hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) {
646646
BoolValue.fromComparison(l, op).restrict(keySet)
647647
case Contig(rgStr) =>
648648
// locus contig equality comparison
649-
val b = getIntervalFromContig(l.asInstanceOf[String], ctx.getReference(rgStr)) match {
649+
val b = getIntervalFromContig(l.asInstanceOf[String], ctx.references(rgStr)) match {
650650
case Some(i) =>
651651
val b = BoolValue(
652652
KeySet(i),
@@ -670,7 +670,7 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) {
670670
case Position(rgStr) =>
671671
// locus position comparison
672672
val posBoolValue = BoolValue.fromComparison(l, op)
673-
val rg = ctx.getReference(rgStr)
673+
val rg = ctx.references(rgStr)
674674
val b = BoolValue(
675675
KeySet(liftPosIntervalsToLocus(posBoolValue.trueBound, rg, ctx)),
676676
KeySet(liftPosIntervalsToLocus(posBoolValue.falseBound, rg, ctx)),

hail/src/main/scala/is/hail/expr/ir/MatrixValue.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ case class MatrixValue(
210210
ReferenceGenome.exportReferences(
211211
fs,
212212
refPath,
213-
ReferenceGenome.getReferences(t).map(ctx.getReference(_)),
213+
ReferenceGenome.getReferences(t).map(ctx.references(_)),
214214
)
215215
}
216216

0 commit comments

Comments
 (0)