Skip to content

Commit c11b3e0

Browse files
committed
[query] remove Backend from HailContext
1 parent 046213b commit c11b3e0

File tree

10 files changed

+47
-43
lines changed

10 files changed

+47
-43
lines changed

hail/hail/src/is/hail/HailContext.scala

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,14 @@
11
package is.hail
22

3-
import is.hail.backend.Backend
43
import is.hail.backend.spark.SparkBackend
54
import is.hail.expr.ir.functions.IRFunctionRegistry
65
import is.hail.io.fs.FS
76
import is.hail.utils._
87

9-
import org.apache.spark._
10-
118
object HailContext {
129

1310
private var theContext: HailContext = _
1411

15-
def get: HailContext = synchronized {
16-
assert(TaskContext.get() == null, "HailContext not available on worker")
17-
assert(theContext != null, "HailContext not initialized")
18-
theContext
19-
}
20-
2112
def checkJavaVersion(): Unit = {
2213
val javaVersion = raw"(\d+)\.(\d+)\.(\d+).*".r
2314
val versionString = System.getProperty("java.version")
@@ -36,13 +27,13 @@ object HailContext {
3627
}
3728
}
3829

39-
def getOrCreate(backend: Backend): HailContext =
30+
def getOrCreate: HailContext =
4031
synchronized {
4132
if (theContext != null) theContext
42-
else HailContext(backend)
33+
else apply
4334
}
4435

45-
def apply(backend: Backend): HailContext = synchronized {
36+
def apply: HailContext = synchronized {
4637
require(theContext == null)
4738
checkJavaVersion()
4839

@@ -60,7 +51,7 @@ object HailContext {
6051
)
6152
}
6253

63-
theContext = new HailContext(backend)
54+
theContext = new HailContext
6455

6556
info(s"Running Hail version $HAIL_PRETTY_VERSION")
6657

@@ -70,14 +61,11 @@ object HailContext {
7061
def stop(): Unit =
7162
synchronized {
7263
IRFunctionRegistry.clearUserFunctions()
73-
theContext.backend.close()
7464
theContext = null
7565
}
7666
}
7767

78-
class HailContext private (
79-
var backend: Backend
80-
) {
68+
class HailContext {
8169
def stop(): Unit = HailContext.stop()
8270

8371
private[this] def fileAndLineCounts(

hail/hail/src/is/hail/backend/Backend.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,18 @@ object Backend {
4343
assert(t.isFieldDefined(off, 0))
4444
codec.encode(ctx, elementType, t.loadField(off, 0), os)
4545
}
46+
47+
// Currently required by `BackendUtils.collectDArray`
48+
private[this] var instance: Backend = _
49+
50+
def set(b: Backend): Unit =
51+
synchronized { instance = b }
52+
53+
def get: Backend =
54+
synchronized {
55+
assert(instance != null)
56+
instance
57+
}
4658
}
4759

4860
abstract class BroadcastValue[T] { def value: T }

hail/hail/src/is/hail/backend/BackendUtils.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package is.hail.backend
22

3-
import is.hail.HailContext
43
import is.hail.annotations.Region
54
import is.hail.asm4s._
65
import is.hail.backend.local.LocalTaskContext
@@ -52,7 +51,7 @@ class BackendUtils(
5251
val remainingPartitions =
5352
contexts.indices.filterNot(k => cachedResults.containsOrdered[Int](k, _ < _, _._2))
5453

55-
val backend = HailContext.get.backend
54+
val backend = Backend.get
5655
val mod = getModule(modID)
5756
val t = System.nanoTime()
5857
val (failureOpt, successes) =

hail/hail/src/is/hail/backend/ExecuteContext.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,9 @@ class ExecuteContext(
149149
)(implicit E: Enclosing
150150
): T =
151151
using(new LocalTaskContext(0, 0)) { tc =>
152-
time {
153-
f(theHailClassLoader, fs, tc, r)
154-
}
152+
Backend.set(backend)
153+
try time(f(theHailClassLoader, fs, tc, r))
154+
finally Backend.set(null)
155155
}
156156

157157
def createTmpPath(prefix: String, extension: String = null, local: Boolean = false): String =

hail/hail/src/is/hail/backend/driver/BatchQueryDriver.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ object BatchQueryDriver extends HttpLikeRpc with Logging {
194194
jobConfig,
195195
)
196196

197-
HailContext(backend)
197+
HailContext.getOrCreate
198198
log.info("HailContext initialized.")
199199

200200
// FIXME: when can the classloader be shared? (optimizer benefits!)
@@ -211,7 +211,10 @@ object BatchQueryDriver extends HttpLikeRpc with Logging {
211211
payload,
212212
)
213213
)
214-
finally HailContext.stop()
214+
finally {
215+
backend.close()
216+
HailContext.stop()
217+
}
215218
}
216219
}
217220

hail/hail/src/is/hail/backend/driver/Py4JQueryDriver.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ final class Py4JQueryDriver(backend: Backend) extends Closeable {
284284
compiledCodeCache.clear()
285285
irCache.clear()
286286
coercerCache.clear()
287+
backend.close()
287288
}
288289

289290
private[this] def removeReference(name: String): Unit =

hail/hail/src/is/hail/backend/service/Worker.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,7 @@ object Worker {
177177
timer.end("readInputs")
178178
timer.start("executeFunction")
179179

180-
// FIXME: workers should not have backends, but some things do need hail contexts
181-
HailContext.getOrCreate(new ServiceBackend(null, null, null, null, null))
180+
HailContext.getOrCreate
182181
val result =
183182
try
184183
using(new ServiceTaskContext(i)) { htc =>

hail/hail/test/src/is/hail/HailSuite.scala

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,24 +54,26 @@ class HailSuite extends TestNGSuite with TestUtils {
5454
def setupHailContext(): Unit = {
5555
Logging.configureLogging("/tmp/hail.log", quiet = false, append = false)
5656
RVD.CheckRvdKeyOrderingForTesting = true
57-
val backend = SparkBackend(
58-
sc = new SparkContext(
59-
SparkBackend.createSparkConf(
60-
appName = "Hail.TestNG",
61-
master = System.getProperty("hail.master"),
62-
local = "local[2]",
63-
blockSize = 0,
64-
)
65-
.set("spark.unsafe.exceptionOnMemoryLeak", "true")
66-
),
67-
skipLoggingConfiguration = true,
57+
Backend.set(
58+
SparkBackend(
59+
sc = new SparkContext(
60+
SparkBackend.createSparkConf(
61+
appName = "Hail.TestNG",
62+
master = System.getProperty("hail.master"),
63+
local = "local[2]",
64+
blockSize = 0,
65+
)
66+
.set("spark.unsafe.exceptionOnMemoryLeak", "true")
67+
),
68+
skipLoggingConfiguration = true,
69+
)
6870
)
69-
HailSuite.hc_ = HailContext(backend)
71+
HailSuite.hc_ = HailContext.getOrCreate
7072
}
7173

7274
@BeforeClass
7375
def setupExecuteContext(): Unit = {
74-
val backend = HailSuite.hc_.backend.asSpark
76+
val backend = Backend.get.asSpark
7577
val conf = new Configuration(backend.sc.hadoopConfiguration)
7678
val fs = new HadoopFS(new SerializableHadoopConfiguration(conf))
7779
val pool = RegionPool()
@@ -103,7 +105,7 @@ class HailSuite extends TestNGSuite with TestUtils {
103105

104106
hadoop.fs.FileSystem.closeAll()
105107

106-
if (HailSuite.hc_.backend.asSpark.sc.isStopped)
108+
if (SparkBackend.sparkContext.isStopped)
107109
throw new RuntimeException(s"'${context.getName}' stopped spark context!")
108110
}
109111

hail/python/hail/backend/local_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(
7878
append,
7979
skip_logging_configuration,
8080
)
81-
jhc = hail_package.HailContext.apply(jbackend)
81+
jhc = hail_package.HailContext.apply()
8282

8383
super().__init__(self._gateway.jvm, jbackend, jhc, tmpdir, tmpdir)
8484
self.gcs_requester_pays_configuration = gcs_requester_pays_configuration

hail/python/hail/backend/spark_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def __init__(
117117
skip_logging_configuration,
118118
min_block_size,
119119
)
120-
jhc = hail_package.HailContext.getOrCreate(jbackend)
120+
jhc = hail_package.HailContext.getOrCreate()
121121
else:
122122
jbackend = hail_package.backend.spark.SparkBackend.apply(
123123
jsc,
@@ -130,7 +130,7 @@ def __init__(
130130
skip_logging_configuration,
131131
min_block_size,
132132
)
133-
jhc = hail_package.HailContext.apply(jbackend)
133+
jhc = hail_package.HailContext.apply()
134134

135135
self._jsc = jbackend.sc()
136136
if sc:

0 commit comments

Comments
 (0)