Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 15 additions & 26 deletions hail/hail/src/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package is.hail.backend

import is.hail.asm4s.HailClassLoader
import is.hail.backend.Backend.PartitionFn
import is.hail.backend.spark.SparkBackend
import is.hail.expr.ir.{IR, LoweringAnalyses, SortField, TableIR, TableReader}
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
Expand Down Expand Up @@ -44,43 +45,31 @@ object Backend {
codec.encode(ctx, elementType, t.loadField(off, 0), os)
}

// Currently required by `BackendUtils.collectDArray`
private[this] var instance: Backend = _

def set(b: Backend): Unit =
synchronized { instance = b }

def get: Backend =
synchronized {
assert(instance != null)
instance
}
type PartitionFn = (Array[Byte], Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]
}

abstract class BroadcastValue[T] { def value: T }

trait BackendContext {
def executionCache: ExecutionCache
}

abstract class Backend extends Closeable {

def defaultParallelism: Int

def canExecuteParallelTasksOnDriver: Boolean = true
abstract class DriverRuntimeContext {

def broadcast[T: ClassTag](value: T): BroadcastValue[T]
def executionCache: ExecutionCache

def parallelizeAndComputeWithIndex(
backendContext: BackendContext,
fs: FS,
def mapCollectPartitions(
globals: Array[Byte],
contexts: IndexedSeq[Array[Byte]],
stageIdentifier: String,
dependency: Option[TableStageDependency] = None,
partitions: Option[IndexedSeq[Int]] = None,
)(
f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]
f: PartitionFn
): (Option[Throwable], IndexedSeq[(Array[Byte], Int)])
}

abstract class Backend extends Closeable {

def defaultParallelism: Int

def broadcast[T: ClassTag](value: T): BroadcastValue[T]

def asSpark(implicit E: Enclosing): SparkBackend =
fatal(s"${getClass.getSimpleName}: ${E.value} requires SparkBackend")
Expand Down Expand Up @@ -118,5 +107,5 @@ abstract class Backend extends Closeable {

def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]

def backendContext(ctx: ExecuteContext): BackendContext
def runtimeContext(ctx: ExecuteContext): DriverRuntimeContext
}
123 changes: 59 additions & 64 deletions hail/hail/src/is/hail/backend/BackendUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,11 @@ package is.hail.backend

import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.backend.local.LocalTaskContext
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.lowering.TableStageDependency
import is.hail.io.fs._
import is.hail.services._
import is.hail.utils._

import scala.util.control.NonFatal

object BackendUtils {
type F = AsmFunction3[Region, Array[Byte], Array[Byte], Array[Byte]]
}
Expand All @@ -27,84 +23,83 @@ class BackendUtils(
def getModule(id: String): (HailClassLoader, FS, HailTaskContext, Region) => F = loadedModules(id)

def collectDArray(
backendContext: BackendContext,
theDriverHailClassLoader: HailClassLoader,
fs: FS,
ctx: DriverRuntimeContext,
modID: String,
contexts: Array[Array[Byte]],
globals: Array[Byte],
stageName: String,
tsd: Option[TableStageDependency],
): Array[Array[Byte]] = {
val (failureOpt, results) = runCDA(ctx, globals, contexts, None, modID, stageName, tsd)
failureOpt.foreach(throw _)
Array.tabulate[Array[Byte]](results.length)(results(_)._1)
}

def ccCollectDArray(
ctx: DriverRuntimeContext,
modID: String,
contexts: Array[Array[Byte]],
globals: Array[Byte],
stageName: String,
semhash: Option[SemanticHash.Type],
semhash: SemanticHash.Type,
tsd: Option[TableStageDependency],
): Array[Array[Byte]] = {

val cachedResults =
semhash
.map { s =>
log.info(s"[collectDArray|$stageName]: querying cache for $s")
val cachedResults = backendContext.executionCache.lookup(s)
log.info(s"[collectDArray|$stageName]: found ${cachedResults.length} entries for $s.")
cachedResults
}
.getOrElse(IndexedSeq.empty)
val cachedResults = ctx.executionCache.lookup(semhash)
log.info(s"$stageName: found ${cachedResults.length} entries for $semhash.")

val remainingPartitions =
contexts.indices.filterNot(k => cachedResults.containsOrdered[Int](k, _ < _, _._2))
val todo =
contexts
.indices
.filterNot(k => cachedResults.containsOrdered[Int](k, _ < _, _._2))

val backend = Backend.get
val mod = getModule(modID)
val t = System.nanoTime()
val (failureOpt, successes) =
remainingPartitions match {
todo match {
case Seq() =>
(None, IndexedSeq.empty)
case Seq(k) if backend.canExecuteParallelTasksOnDriver =>
try
using(new LocalTaskContext(k, 0)) { htc =>
using(htc.getRegionPool().getRegion()) { r =>
val f = mod(theDriverHailClassLoader, fs, htc, r)
val res = retryTransientErrors(f(r, contexts(k), globals))
(None, FastSeq(res -> k))
}
}
catch {
case NonFatal(ex) =>
(Some(ex), IndexedSeq.empty)
}

case partitions =>
val globalsBC = backend.broadcast(globals)
val fsConfigBC = backend.broadcast(fs.getConfiguration())
backend.parallelizeAndComputeWithIndex(
backendContext,
fs,
contexts,
stageName,
tsd,
Some(partitions),
) { (ctx, htc, theHailClassLoader, fs) =>
val fsConfig = fsConfigBC.value
val gs = globalsBC.value
fs.setConfiguration(fsConfig)
htc.getRegionPool().scopedRegion { region =>
mod(theHailClassLoader, fs, htc, region)(region, ctx, gs)
}
}
runCDA(ctx, globals, contexts, Some(partitions), modID, stageName, tsd)
}

log.info(
s"[collectDArray|$stageName]: executed ${remainingPartitions.length} tasks in ${formatTime(System.nanoTime() - t)}"
)
val results = merge[(Array[Byte], Int)](cachedResults, successes, _._2 < _._2)

val results =
merge[(Array[Byte], Int)](
cachedResults,
successes.sortBy(_._2),
_._2 < _._2,
)
ctx.executionCache.put(semhash, results)
log.info(s"$stageName: cached ${results.length} entries for $semhash.")

semhash.foreach(s => backendContext.executionCache.put(s, results))
failureOpt.foreach(throw _)
Array.tabulate[Array[Byte]](results.length)(results(_)._1)
}

private[this] def runCDA(
rtx: DriverRuntimeContext,
globals: Array[Byte],
contexts: Array[Array[Byte]],
partitions: Option[IndexedSeq[Int]],
modID: String,
stageName: String,
tsd: Option[TableStageDependency],
): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = {

val mod = getModule(modID)
val start = System.nanoTime()

val r = rtx.mapCollectPartitions(
globals,
contexts,
stageName,
tsd,
partitions,
) { (gs, ctx, htc, theHailClassLoader, fs) =>
htc.getRegionPool().scopedRegion { region =>
mod(theHailClassLoader, fs, htc, region)(region, ctx, gs)
}
}

val elapsed = System.nanoTime() - start
val nTasks = partitions.map(_.length).getOrElse(contexts.length)
log.info(s"$stageName: executed $nTasks tasks in ${formatTime(elapsed)}")

results.map(_._1).toArray
r
}
}
7 changes: 1 addition & 6 deletions hail/hail/src/is/hail/backend/driver/BatchQueryDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,6 @@ object BatchQueryDriver extends HttpLikeRpc with Logging {
jobConfig,
)

Backend.set(backend)

// FIXME: when can the classloader be shared? (optimizer benefits!)
try runRpc(
Env(
Expand All @@ -207,10 +205,7 @@ object BatchQueryDriver extends HttpLikeRpc with Logging {
payload,
)
)
finally {
Backend.set(null)
backend.close()
}
finally backend.close()
}
}

Expand Down
3 changes: 0 additions & 3 deletions hail/hail/src/is/hail/backend/driver/Py4JQueryDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ final class Py4JQueryDriver(backend: Backend) extends Closeable {
newFs(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
)

Backend.set(backend)

def pyFs: FS =
synchronized(tmpFileManager.fs)

Expand Down Expand Up @@ -307,7 +305,6 @@ final class Py4JQueryDriver(backend: Backend) extends Closeable {
compiledCodeCache.clear()
irCache.clear()
coercerCache.clear()
Backend.set(null)
backend.close()
IRFunctionRegistry.clearUserFunctions()
}
Expand Down
80 changes: 50 additions & 30 deletions hail/hail/src/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
package is.hail.backend.local

import is.hail.CancellingExecutorService
import is.hail.asm4s._
import is.hail.backend._
import is.hail.backend.Backend.PartitionFn
import is.hail.expr.Validate
import is.hail.expr.ir._
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
import is.hail.types._
import is.hail.types.physical.PTuple
import is.hail.utils._
import is.hail.utils.compat.immutable.ArraySeq

import scala.reflect.ClassTag
import scala.util.control.NonFatal

import java.io.PrintWriter

import com.fasterxml.jackson.core.StreamReadConstraints
import com.google.common.util.concurrent.MoreExecutors

class LocalBroadcastValue[T](val value: T) extends BroadcastValue[T] with Serializable

Expand Down Expand Up @@ -53,9 +52,6 @@ object LocalBackend extends Backend {
this
}

private case class Context(hcl: HailClassLoader, override val executionCache: ExecutionCache)
extends BackendContext

def broadcast[T: ClassTag](value: T): BroadcastValue[T] = new LocalBroadcastValue[T](value)

private[this] var stageIdx: Int = 0
Expand All @@ -67,32 +63,54 @@ object LocalBackend extends Backend {
current
}

override def parallelizeAndComputeWithIndex(
ctx: BackendContext,
fs: FS,
contexts: IndexedSeq[Array[Byte]],
stageIdentifier: String,
dependency: Option[TableStageDependency],
partitions: Option[IndexedSeq[Int]],
)(
f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]
): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = {

val stageId = nextStageId()
val hcl = ctx.asInstanceOf[Context].hcl
runAllKeepFirstError(new CancellingExecutorService(MoreExecutors.newDirectExecutorService())) {
partitions.getOrElse(contexts.indices).map { i =>
(
() => using(new LocalTaskContext(i, stageId))(f(contexts(i), _, hcl, fs)),
i,
)
override def runtimeContext(ctx: ExecuteContext): DriverRuntimeContext = {
new DriverRuntimeContext {

override val executionCache: ExecutionCache =
ExecutionCache.fromFlags(ctx.flags, ctx.fs, ctx.localTmpdir)

override def mapCollectPartitions(
globals: Array[Byte],
contexts: IndexedSeq[Array[Byte]],
stageIdentifier: String,
dependency: Option[TableStageDependency],
partitions: Option[IndexedSeq[Int]],
)(
f: PartitionFn
): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = {

val todo: IndexedSeq[Int] =
partitions.getOrElse(contexts.indices)

val results = ArraySeq.newBuilder[(Array[Byte], Int)]
results.sizeHint(todo.length)

var failure: Option[Throwable] =
None

val stageId = nextStageId()

try
for (idx <- todo) {
using(new LocalTaskContext(idx, stageId)) { tx =>
results += {
(
f(globals, contexts(idx), tx, ctx.theHailClassLoader, ctx.fs),
idx,
)
}
}
}
catch {
case NonFatal(t) =>
failure = Some(t)
}

(failure, results.result())
}
}
}

override def backendContext(ctx: ExecuteContext): BackendContext =
Context(ctx.theHailClassLoader, ExecutionCache.fromFlags(ctx.flags, ctx.fs, ctx.tmpdir))

def defaultParallelism: Int = 1

def close(): Unit =
Expand All @@ -116,7 +134,9 @@ object LocalBackend extends Backend {
Validate(ir)
val queryID = Backend.nextID()
log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}")
ctx.irMetadata.semhash = SemanticHash(ctx, ir)
if (ctx.flags.isDefined(ExecutionCache.Flags.UseFastRestarts))
ctx.irMetadata.semhash = SemanticHash(ctx, ir)

val res = _jvmLowerAndExecute(ctx, ir)
log.info(s"finished execution of query $queryID")
res
Expand Down
Loading