Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 5 additions & 17 deletions hail/hail/src/is/hail/HailContext.scala
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
package is.hail

import is.hail.backend.Backend
import is.hail.backend.spark.SparkBackend
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.io.fs.FS
import is.hail.utils._

import org.apache.spark._

object HailContext {

private var theContext: HailContext = _

/** Returns the singleton [[HailContext]] stored in `theContext`.
  *
  * Two assertions guard the lookup: no Spark `TaskContext` may be active for the caller (the
  * message treats an active one as running on a worker), and `theContext` must already be set.
  */
def get: HailContext =
  synchronized {
    val insideTask = TaskContext.get() != null
    assert(!insideTask, "HailContext not available on worker")
    assert(theContext != null, "HailContext not initialized")
    theContext
  }

def checkJavaVersion(): Unit = {
val javaVersion = raw"(\d+)\.(\d+)\.(\d+).*".r
val versionString = System.getProperty("java.version")
Expand All @@ -36,13 +27,13 @@ object HailContext {
}
}

def getOrCreate(backend: Backend): HailContext =
def getOrCreate: HailContext =
synchronized {
if (theContext != null) theContext
else HailContext(backend)
else apply
}

def apply(backend: Backend): HailContext = synchronized {
def apply: HailContext = synchronized {
require(theContext == null)
checkJavaVersion()

Expand All @@ -60,7 +51,7 @@ object HailContext {
)
}

theContext = new HailContext(backend)
theContext = new HailContext

info(s"Running Hail version $HAIL_PRETTY_VERSION")

Expand All @@ -70,14 +61,11 @@ object HailContext {
def stop(): Unit =
synchronized {
IRFunctionRegistry.clearUserFunctions()
theContext.backend.close()
theContext = null
}
}

class HailContext private (
var backend: Backend
) {
class HailContext {
def stop(): Unit = HailContext.stop()

private[this] def fileAndLineCounts(
Expand Down
12 changes: 12 additions & 0 deletions hail/hail/src/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ object Backend {
assert(t.isFieldDefined(off, 0))
codec.encode(ctx, elementType, t.loadField(off, 0), os)
}

// Currently required by `BackendUtils.collectDArray`
private[this] var instance: Backend = _

/** Installs `b` as the current [[Backend]]; callers pass `null` to clear it again. */
def set(b: Backend): Unit =
  synchronized { instance = b }

/** Returns the [[Backend]] most recently installed with `set`.
  *
  * Fails an assertion, with a message naming the problem, when no backend is currently
  * installed (never set, or cleared with `set(null)`).
  */
def get: Backend =
  synchronized {
    // Carry a message, like the asserts in `HailContext.get`, so a failure is self-explanatory.
    assert(instance != null, "Backend not initialized")
    instance
  }
Comment on lines +47 to +57
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Temporary; this will be removed in #15107.

}

/** Abstract holder that exposes a single value of type `T` through `value`.
  *
  * NOTE(review): the name suggests the value is broadcast to workers by a backend; that is not
  * visible here — confirm against the implementations.
  */
abstract class BroadcastValue[T] {

  /** The wrapped value. */
  def value: T
}
Expand Down
3 changes: 1 addition & 2 deletions hail/hail/src/is/hail/backend/BackendUtils.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package is.hail.backend

import is.hail.HailContext
import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.backend.local.LocalTaskContext
Expand Down Expand Up @@ -52,7 +51,7 @@ class BackendUtils(
val remainingPartitions =
contexts.indices.filterNot(k => cachedResults.containsOrdered[Int](k, _ < _, _._2))

val backend = HailContext.get.backend
val backend = Backend.get
val mod = getModule(modID)
val t = System.nanoTime()
val (failureOpt, successes) =
Expand Down
6 changes: 1 addition & 5 deletions hail/hail/src/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,7 @@ class ExecuteContext(
f: (HailClassLoader, FS, HailTaskContext, Region) => T
)(implicit E: Enclosing
): T =
using(new LocalTaskContext(0, 0)) { tc =>
time {
f(theHailClassLoader, fs, tc, r)
}
}
using(new LocalTaskContext(0, 0))(tc => time(f(theHailClassLoader, fs, tc, r)))

/** Asks `tempFileManager` for a new temporary path.
  *
  * @param prefix
  *   forwarded unchanged to `newTmpPath`
  * @param extension
  *   forwarded unchanged to `newTmpPath`; defaults to `null`
  * @param local
  *   when true the path is rooted at `localTmpdir`, otherwise at `tmpdir`
  */
def createTmpPath(prefix: String, extension: String = null, local: Boolean = false): String = {
  // Read the manager first so evaluation order matches a direct call on `tempFileManager`.
  val manager = tempFileManager
  val baseDir = if (local) localTmpdir else tmpdir
  manager.newTmpPath(baseDir, prefix, extension)
}
Expand Down
9 changes: 7 additions & 2 deletions hail/hail/src/is/hail/backend/driver/BatchQueryDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ object BatchQueryDriver extends HttpLikeRpc with Logging {
jobConfig,
)

HailContext(backend)
HailContext.getOrCreate
Backend.set(backend)
log.info("HailContext initialized.")

// FIXME: when can the classloader be shared? (optimizer benefits!)
Expand All @@ -211,7 +212,11 @@ object BatchQueryDriver extends HttpLikeRpc with Logging {
payload,
)
)
finally HailContext.stop()
finally {
HailContext.stop()
Backend.set(null)
backend.close()
}
}
}

Expand Down
4 changes: 4 additions & 0 deletions hail/hail/src/is/hail/backend/driver/Py4JQueryDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ final class Py4JQueryDriver(backend: Backend) extends Closeable {
newFs(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
)

Backend.set(backend)

/** The [[FS]] currently held by `tmpFileManager`, read while holding this object's monitor. */
def pyFs: FS =
  synchronized {
    tmpFileManager.fs
  }

Expand Down Expand Up @@ -284,6 +286,8 @@ final class Py4JQueryDriver(backend: Backend) extends Closeable {
compiledCodeCache.clear()
irCache.clear()
coercerCache.clear()
Backend.set(null)
backend.close()
}

private[this] def removeReference(name: String): Unit =
Expand Down
3 changes: 1 addition & 2 deletions hail/hail/src/is/hail/backend/service/Worker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,7 @@ object Worker {
timer.end("readInputs")
timer.start("executeFunction")

// FIXME: workers should not have backends, but some things do need hail contexts
HailContext.getOrCreate(new ServiceBackend(null, null, null, null, null))
HailContext.getOrCreate
val result =
try
using(new ServiceTaskContext(i)) { htc =>
Expand Down
34 changes: 20 additions & 14 deletions hail/hail/test/src/is/hail/HailSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,24 +54,26 @@ class HailSuite extends TestNGSuite with TestUtils {
def setupHailContext(): Unit = {
Logging.configureLogging("/tmp/hail.log", quiet = false, append = false)
RVD.CheckRvdKeyOrderingForTesting = true
val backend = SparkBackend(
sc = new SparkContext(
SparkBackend.createSparkConf(
appName = "Hail.TestNG",
master = System.getProperty("hail.master"),
local = "local[2]",
blockSize = 0,
)
.set("spark.unsafe.exceptionOnMemoryLeak", "true")
),
skipLoggingConfiguration = true,
Backend.set(
SparkBackend(
sc = new SparkContext(
SparkBackend.createSparkConf(
appName = "Hail.TestNG",
master = System.getProperty("hail.master"),
local = "local[2]",
blockSize = 0,
)
.set("spark.unsafe.exceptionOnMemoryLeak", "true")
),
skipLoggingConfiguration = true,
)
)
HailSuite.hc_ = HailContext(backend)
HailSuite.hc_ = HailContext.getOrCreate
}

@BeforeClass
def setupExecuteContext(): Unit = {
val backend = HailSuite.hc_.backend.asSpark
val backend = Backend.get.asSpark
val conf = new Configuration(backend.sc.hadoopConfiguration)
val fs = new HadoopFS(new SerializableHadoopConfiguration(conf))
val pool = RegionPool()
Expand Down Expand Up @@ -103,12 +105,16 @@ class HailSuite extends TestNGSuite with TestUtils {

hadoop.fs.FileSystem.closeAll()

if (HailSuite.hc_.backend.asSpark.sc.isStopped)
if (SparkBackend.sparkContext.isStopped)
throw new RuntimeException(s"'${context.getName}' stopped spark context!")
}

@AfterSuite
/** Suite-level teardown: unregisters and closes the backend, then stops the HailContext.
  *
  * The backend is removed from the `Backend` registry before it is closed. Stopping the
  * HailContext and clearing `HailSuite.hc_` happen in a `finally` so they still run when
  * `backend.close()` throws.
  */
def tearDownHailContext(): Unit = {
  val backend = Backend.get
  Backend.set(null)

  try backend.close()
  finally {
    HailSuite.hc_.stop()
    HailSuite.hc_ = null
  }
}
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/backend/local_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(
append,
skip_logging_configuration,
)
jhc = hail_package.HailContext.apply(jbackend)
jhc = hail_package.HailContext.apply()

super().__init__(self._gateway.jvm, jbackend, jhc, tmpdir, tmpdir)
self.gcs_requester_pays_configuration = gcs_requester_pays_configuration
Expand Down
4 changes: 2 additions & 2 deletions hail/python/hail/backend/spark_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __init__(
skip_logging_configuration,
min_block_size,
)
jhc = hail_package.HailContext.getOrCreate(jbackend)
jhc = hail_package.HailContext.getOrCreate()
else:
jbackend = hail_package.backend.spark.SparkBackend.apply(
jsc,
Expand All @@ -130,7 +130,7 @@ def __init__(
skip_logging_configuration,
min_block_size,
)
jhc = hail_package.HailContext.apply(jbackend)
jhc = hail_package.HailContext.apply()

self._jsc = jbackend.sc()
if sc:
Expand Down