diff --git a/hail/hail/src/is/hail/HailContext.scala b/hail/hail/src/is/hail/HailContext.scala index a0ca2290f84..a2080d12e2e 100644 --- a/hail/hail/src/is/hail/HailContext.scala +++ b/hail/hail/src/is/hail/HailContext.scala @@ -1,23 +1,14 @@ package is.hail -import is.hail.backend.Backend import is.hail.backend.spark.SparkBackend import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.io.fs.FS import is.hail.utils._ -import org.apache.spark._ - object HailContext { private var theContext: HailContext = _ - def get: HailContext = synchronized { - assert(TaskContext.get() == null, "HailContext not available on worker") - assert(theContext != null, "HailContext not initialized") - theContext - } - def checkJavaVersion(): Unit = { val javaVersion = raw"(\d+)\.(\d+)\.(\d+).*".r val versionString = System.getProperty("java.version") @@ -36,13 +27,13 @@ object HailContext { } } - def getOrCreate(backend: Backend): HailContext = + def getOrCreate: HailContext = synchronized { if (theContext != null) theContext - else HailContext(backend) + else apply } - def apply(backend: Backend): HailContext = synchronized { + def apply: HailContext = synchronized { require(theContext == null) checkJavaVersion() @@ -60,7 +51,7 @@ object HailContext { ) } - theContext = new HailContext(backend) + theContext = new HailContext info(s"Running Hail version $HAIL_PRETTY_VERSION") @@ -70,14 +61,11 @@ object HailContext { def stop(): Unit = synchronized { IRFunctionRegistry.clearUserFunctions() - theContext.backend.close() theContext = null } } -class HailContext private ( - var backend: Backend -) { +class HailContext { def stop(): Unit = HailContext.stop() private[this] def fileAndLineCounts( diff --git a/hail/hail/src/is/hail/backend/Backend.scala b/hail/hail/src/is/hail/backend/Backend.scala index 71a0ef4e75e..3b8cb397340 100644 --- a/hail/hail/src/is/hail/backend/Backend.scala +++ b/hail/hail/src/is/hail/backend/Backend.scala @@ -43,6 +43,18 @@ object Backend { assert(t.isFieldDefined(off, 0)) codec.encode(ctx, elementType, t.loadField(off, 0), os) } + + // Currently required by `BackendUtils.collectDArray` + private[this] var instance: Backend = _ + + def set(b: Backend): Unit = + synchronized { instance = b } + + def get: Backend = + synchronized { + assert(instance != null) + instance + } } abstract class BroadcastValue[T] { def value: T } diff --git a/hail/hail/src/is/hail/backend/BackendUtils.scala b/hail/hail/src/is/hail/backend/BackendUtils.scala index 2d18d6c9b35..e285b54ea82 100644 --- a/hail/hail/src/is/hail/backend/BackendUtils.scala +++ b/hail/hail/src/is/hail/backend/BackendUtils.scala @@ -1,6 +1,5 @@ package is.hail.backend -import is.hail.HailContext import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.local.LocalTaskContext @@ -52,7 +51,7 @@ class BackendUtils( val remainingPartitions = contexts.indices.filterNot(k => cachedResults.containsOrdered[Int](k, _ < _, _._2)) - val backend = HailContext.get.backend + val backend = Backend.get val mod = getModule(modID) val t = System.nanoTime() val (failureOpt, successes) = diff --git a/hail/hail/src/is/hail/backend/ExecuteContext.scala b/hail/hail/src/is/hail/backend/ExecuteContext.scala index 8b929373e8e..e03577c61d6 100644 --- a/hail/hail/src/is/hail/backend/ExecuteContext.scala +++ b/hail/hail/src/is/hail/backend/ExecuteContext.scala @@ -148,11 +148,7 @@ class ExecuteContext( f: (HailClassLoader, FS, HailTaskContext, Region) => T )(implicit E: Enclosing ): T = - using(new LocalTaskContext(0, 0)) { tc => - time { - f(theHailClassLoader, fs, tc, r) - } - } + using(new LocalTaskContext(0, 0))(tc => time(f(theHailClassLoader, fs, tc, r))) def createTmpPath(prefix: String, extension: String = null, local: Boolean = false): String = tempFileManager.newTmpPath(if (local) localTmpdir else tmpdir, prefix, extension) diff --git a/hail/hail/src/is/hail/backend/driver/BatchQueryDriver.scala b/hail/hail/src/is/hail/backend/driver/BatchQueryDriver.scala index e903995f64a..66eeb910af5 100644 --- a/hail/hail/src/is/hail/backend/driver/BatchQueryDriver.scala +++ b/hail/hail/src/is/hail/backend/driver/BatchQueryDriver.scala @@ -194,7 +194,8 @@ object BatchQueryDriver extends HttpLikeRpc with Logging { jobConfig, ) - HailContext(backend) + HailContext.getOrCreate + Backend.set(backend) log.info("HailContext initialized.") // FIXME: when can the classloader be shared? (optimizer benefits!) @@ -211,7 +212,11 @@ object BatchQueryDriver extends HttpLikeRpc with Logging { payload, ) ) - finally HailContext.stop() + finally { + HailContext.stop() + Backend.set(null) + backend.close() + } } } diff --git a/hail/hail/src/is/hail/backend/driver/Py4JQueryDriver.scala b/hail/hail/src/is/hail/backend/driver/Py4JQueryDriver.scala index 0dcc5a8a0e8..47c67ab1ac9 100644 --- a/hail/hail/src/is/hail/backend/driver/Py4JQueryDriver.scala +++ b/hail/hail/src/is/hail/backend/driver/Py4JQueryDriver.scala @@ -55,6 +55,8 @@ final class Py4JQueryDriver(backend: Backend) extends Closeable { newFs(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) ) + Backend.set(backend) + def pyFs: FS = synchronized(tmpFileManager.fs) @@ -284,6 +286,8 @@ final class Py4JQueryDriver(backend: Backend) extends Closeable { compiledCodeCache.clear() irCache.clear() coercerCache.clear() + Backend.set(null) + backend.close() } private[this] def removeReference(name: String): Unit = diff --git a/hail/hail/src/is/hail/backend/service/Worker.scala b/hail/hail/src/is/hail/backend/service/Worker.scala index 42d461d38e9..0d97cbee039 100644 --- a/hail/hail/src/is/hail/backend/service/Worker.scala +++ b/hail/hail/src/is/hail/backend/service/Worker.scala @@ -177,8 +177,7 @@ object Worker { timer.end("readInputs") timer.start("executeFunction") - // FIXME: workers should not have backends, but some things do need hail contexts - HailContext.getOrCreate(new ServiceBackend(null, null, null, null, null)) + HailContext.getOrCreate val result = try using(new ServiceTaskContext(i)) { htc => diff --git a/hail/hail/test/src/is/hail/HailSuite.scala b/hail/hail/test/src/is/hail/HailSuite.scala index 97e6d9629d2..ea0295ab739 100644 --- a/hail/hail/test/src/is/hail/HailSuite.scala +++ b/hail/hail/test/src/is/hail/HailSuite.scala @@ -54,24 +54,26 @@ class HailSuite extends TestNGSuite with TestUtils { def setupHailContext(): Unit = { Logging.configureLogging("/tmp/hail.log", quiet = false, append = false) RVD.CheckRvdKeyOrderingForTesting = true - val backend = SparkBackend( - sc = new SparkContext( - SparkBackend.createSparkConf( - appName = "Hail.TestNG", - master = System.getProperty("hail.master"), - local = "local[2]", - blockSize = 0, - ) - .set("spark.unsafe.exceptionOnMemoryLeak", "true") - ), - skipLoggingConfiguration = true, + Backend.set( + SparkBackend( + sc = new SparkContext( + SparkBackend.createSparkConf( + appName = "Hail.TestNG", + master = System.getProperty("hail.master"), + local = "local[2]", + blockSize = 0, + ) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + ), + skipLoggingConfiguration = true, + ) ) - HailSuite.hc_ = HailContext(backend) + HailSuite.hc_ = HailContext.getOrCreate } @BeforeClass def setupExecuteContext(): Unit = { - val backend = HailSuite.hc_.backend.asSpark + val backend = Backend.get.asSpark val conf = new Configuration(backend.sc.hadoopConfiguration) val fs = new HadoopFS(new SerializableHadoopConfiguration(conf)) val pool = RegionPool() @@ -103,12 +105,16 @@ class HailSuite extends TestNGSuite with TestUtils { hadoop.fs.FileSystem.closeAll() - if (HailSuite.hc_.backend.asSpark.sc.isStopped) + if (SparkBackend.sparkContext.isStopped) throw new RuntimeException(s"'${context.getName}' stopped spark context!") } @AfterSuite def tearDownHailContext(): Unit = { + val backend = Backend.get + Backend.set(null) + backend.close() + HailSuite.hc_.stop() HailSuite.hc_ = null } diff --git a/hail/python/hail/backend/local_backend.py b/hail/python/hail/backend/local_backend.py index 289bd53024b..524b598c7c3 100644 --- a/hail/python/hail/backend/local_backend.py +++ b/hail/python/hail/backend/local_backend.py @@ -78,7 +78,7 @@ def __init__( append, skip_logging_configuration, ) - jhc = hail_package.HailContext.apply(jbackend) + jhc = hail_package.HailContext.apply() super().__init__(self._gateway.jvm, jbackend, jhc, tmpdir, tmpdir) self.gcs_requester_pays_configuration = gcs_requester_pays_configuration diff --git a/hail/python/hail/backend/spark_backend.py b/hail/python/hail/backend/spark_backend.py index c3dcd2195cf..10a06b6fb52 100644 --- a/hail/python/hail/backend/spark_backend.py +++ b/hail/python/hail/backend/spark_backend.py @@ -117,7 +117,7 @@ def __init__( skip_logging_configuration, min_block_size, ) - jhc = hail_package.HailContext.getOrCreate(jbackend) + jhc = hail_package.HailContext.getOrCreate() else: jbackend = hail_package.backend.spark.SparkBackend.apply( jsc, @@ -130,7 +130,7 @@ def __init__( skip_logging_configuration, min_block_size, ) - jhc = hail_package.HailContext.apply(jbackend) + jhc = hail_package.HailContext.apply() self._jsc = jbackend.sc() if sc: