
Commit 8760f47

[query] remove unnecessary references to HailContext
1 parent 33d0b37 commit 8760f47

40 files changed: +342 -283 lines

hail/hail/src/is/hail/HailContext.scala

Lines changed: 11 additions & 33 deletions
@@ -1,11 +1,9 @@
 package is.hail
 
-import is.hail.backend.Backend
+import is.hail.backend.{Backend, ExecuteContext}
 import is.hail.backend.spark.SparkBackend
 import is.hail.expr.ir.functions.IRFunctionRegistry
 import is.hail.io.fs.FS
-import is.hail.io.vcf._
-import is.hail.types.virtual._
 import is.hail.utils._
 
 import scala.reflect.ClassTag
@@ -17,33 +15,21 @@ import org.apache.log4j.{LogManager, PropertyConfigurator}
 import org.apache.spark._
 import org.apache.spark.executor.InputMetrics
 import org.apache.spark.rdd.RDD
-import org.json4s.Extraction
-import org.json4s.jackson.JsonMethods
-import sourcecode.Enclosing
 
 case class FilePartition(index: Int, file: String) extends Partition
 
 object HailContext {
-  val tera: Long = 1024L * 1024L * 1024L * 1024L
 
   val logFormat: String = "%d{yyyy-MM-dd HH:mm:ss.SSS} %c{1}: %p: %m%n"
 
   private var theContext: HailContext = _
 
-  def isInitialized: Boolean = synchronized {
-    theContext != null
-  }
-
   def get: HailContext = synchronized {
     assert(TaskContext.get() == null, "HailContext not available on worker")
     assert(theContext != null, "HailContext not initialized")
     theContext
   }
 
-  def backend: Backend = get.backend
-
-  def sparkBackend(implicit E: Enclosing): SparkBackend = get.backend.asSpark
-
   def configureLogging(logFile: String, quiet: Boolean, append: Boolean): Unit = {
     org.apache.log4j.helpers.LogLog.setInternalDebugging(true)
     org.apache.log4j.helpers.LogLog.setQuietMode(false)
@@ -94,7 +80,7 @@ object HailContext {
 
   def getOrCreate(backend: Backend): HailContext =
     synchronized {
-      if (isInitialized) theContext
+      if (theContext != null) theContext
       else HailContext(backend)
     }
 
@@ -123,25 +109,25 @@ object HailContext {
     theContext
   }
 
-  def stop(): Unit = synchronized {
-    IRFunctionRegistry.clearUserFunctions()
-    backend.close()
-
-    theContext = null
-  }
+  def stop(): Unit =
+    synchronized {
+      IRFunctionRegistry.clearUserFunctions()
+      theContext.backend.close()
+      theContext = null
+    }
 
   def readPartitions[T: ClassTag](
-    fs: FS,
+    ctx: ExecuteContext,
     path: String,
     partFiles: IndexedSeq[String],
     read: (Int, InputStream, InputMetrics) => Iterator[T],
     optPartitioner: Option[Partitioner] = None,
   ): RDD[T] = {
     val nPartitions = partFiles.length
 
-    val fsBc = fs.broadcast
+    val fsBc = ctx.fsBc
 
-    new RDD[T](SparkBackend.sparkContext, Nil) {
+    new RDD[T](ctx.backend.asSpark.sc, Nil) {
       def getPartitions: Array[Partition] =
         Array.tabulate(nPartitions)(i => FilePartition(i, partFiles(i)))
 
@@ -192,12 +178,4 @@ class HailContext private (
       : Array[(String, Array[String])] =
     fileAndLineCounts(fs: FS, regex, files, maxLines).mapValues(_.map(_.value)).toArray
 
-  def parseVCFMetadata(fs: FS, file: String): Map[String, Map[String, Map[String, String]]] =
-    LoadVCF.parseHeaderMetadata(fs, Set.empty, TFloat64, file)
-
-  def pyParseVCFMetadataJSON(fs: FS, file: String): String = {
-    val metadata = LoadVCF.parseHeaderMetadata(fs, Set.empty, TFloat64, file)
-    implicit val formats = defaultJSONFormats
-    JsonMethods.compact(Extraction.decompose(metadata))
-  }
 }
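
For reference, a minimal, self-contained sketch of the synchronized-singleton lifecycle this file keeps after the cleanup: getOrCreate checks the field directly rather than going through the removed isInitialized helper, and stop closes the backend through the stored instance rather than the removed backend forwarder. Illustrative names only; this is not Hail code.

final class Ctx(val backend: AutoCloseable)

object Ctx {
  private var theCtx: Ctx = _

  // get asserts initialization instead of exposing an isInitialized helper
  def get: Ctx = synchronized {
    assert(theCtx != null, "Ctx not initialized")
    theCtx
  }

  // getOrCreate checks the field directly, as the commit does
  def getOrCreate(backend: AutoCloseable): Ctx = synchronized {
    if (theCtx != null) theCtx
    else { theCtx = new Ctx(backend); theCtx }
  }

  // stop closes through the stored instance, not a global forwarder
  def stop(): Unit = synchronized {
    theCtx.backend.close()
    theCtx = null
  }
}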

hail/hail/src/is/hail/backend/BackendUtils.scala

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ class BackendUtils(
     val remainingPartitions =
       contexts.indices.filterNot(k => cachedResults.containsOrdered[Int](k, _ < _, _._2))
 
-    val backend = HailContext.backend
+    val backend = HailContext.get.backend
     val mod = getModule(modID)
     val t = System.nanoTime()
     val (failureOpt, successes) =

hail/hail/src/is/hail/backend/ExecuteContext.scala

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ class ExecuteContext(
 
   val stateManager = HailStateManager(references)
 
-  def fsBc: BroadcastValue[FS] = fs.broadcast
+  lazy val fsBc: BroadcastValue[FS] = backend.broadcast(fs)
 
   val memo: mutable.Map[Any, Any] = new mutable.HashMap[Any, Any]()
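
The def-to-lazy-val switch means the filesystem handle is broadcast at most once per ExecuteContext instead of once per access. A minimal sketch of that difference in plain Scala (counter names are illustrative, not Hail code):

// def re-evaluates its body on every access; lazy val evaluates once and
// caches. Run as a script: prints 2 for the def, 1 for the lazy val.
object BroadcastOnce extends App {
  def expensive(counter: Array[Int]): Int = { counter(0) += 1; counter(0) }

  val defCount = Array(0)
  val lazyCount = Array(0)

  def asDef: Int = expensive(defCount)
  lazy val asLazyVal: Int = expensive(lazyCount)

  asDef; asDef            // evaluated twice
  asLazyVal; asLazyVal    // evaluated once, then cached
  println(s"def: ${defCount(0)}, lazy val: ${lazyCount(0)}")
}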

hail/hail/src/is/hail/backend/driver/Py4JQueryDriver.scala

Lines changed: 3 additions & 2 deletions
@@ -138,7 +138,7 @@ final class Py4JQueryDriver(backend: Backend) extends Closeable {
   ): Unit = {
     void {
       withExecuteContext() { ctx =>
-        val rm = linalg.RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize)
+        val rm = linalg.RowMatrix.readBlockMatrix(ctx, pathIn, partitionSize)
         entries match {
           case "full" =>
             rm.export(ctx, pathOut, delimiter, Option(header), addIndex, exportType)
@@ -225,7 +225,8 @@ final class Py4JQueryDriver(backend: Backend) extends Closeable {
   def pyToDF(s: String): DataFrame =
     withExecuteContext(selfContainedExecution = false) { ctx =>
       val tir = IRParser.parse_table_ir(ctx, s)
-      Interpret(tir, ctx).toDF()
+      val tv = Interpret(tir, ctx)
+      tv.toDF(ctx)
     }._1
 
   def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] =
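
pyToDF now names the interpreted table value and passes the context into toDF explicitly, matching the commit's broader move away from global lookups. A small sketch of the pattern (illustrative names, not Hail's API):

// Threading a context argument instead of resolving a global singleton.
final case class Ctx(sessionName: String)

final case class TableValue(rows: Seq[Int]) {
  // before (sketch): toDF() would reach for a global context internally
  def toDF(ctx: Ctx): String =
    s"DataFrame(${rows.mkString(", ")}) via ${ctx.sessionName}"
}

object ExplicitCtxDemo extends App {
  val ctx = Ctx("spark")
  val tv = TableValue(Seq(1, 2, 3))
  println(tv.toDF(ctx))
}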

hail/hail/src/is/hail/backend/spark/SparkBackend.scala

Lines changed: 5 additions & 1 deletion
@@ -83,7 +83,11 @@ object SparkBackend {
 
   private var theSparkBackend: SparkBackend = _
 
-  def sparkContext(implicit E: Enclosing): SparkContext = HailContext.sparkBackend.sc
+  def sparkContext(implicit E: Enclosing): SparkContext =
+    synchronized {
+      if (theSparkBackend == null) throw new IllegalStateException(E.value)
+      else theSparkBackend.sc
+    }
 
   def checkSparkCompatibility(jarVersion: String, sparkVersion: String): Unit = {
     def majorMinor(version: String): String = version.split("\\.", 3).take(2).mkString(".")
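
sparkContext now fails fast with an IllegalStateException when no backend exists, and the implicit sourcecode.Enclosing (materialized by macro at each call site) puts the caller's location in the message. A sketch of that pattern, assuming the com.lihaoyi sourcecode library is on the classpath (not Hail code):

import sourcecode.Enclosing

object Registry {
  private var instance: String = _

  def set(v: String): Unit = synchronized { instance = v }

  // E is filled in at each call site, so E.value names the enclosing
  // definition of whoever asked for the instance too early.
  def getOrThrow(implicit E: Enclosing): String = synchronized {
    if (instance == null)
      throw new IllegalStateException(s"not initialized; requested from ${E.value}")
    else instance
  }
}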
