Skip to content

Commit e829711

Browse files
authored
[query] remove global Backend field (#15107)
## Change Description This PR refactors the Backend interface to improve separation of concerns between driver and worker contexts. The key changes include: 1. Removes the static `Backend.instance` singleton pattern in favour of explicit context passing 2. Repurposes `BackendContext`​ as `DriverRuntimeContext` to implement `mapCollectPartitions`​ (formally `parallelizeAndComputeWithIndex`) 3. Removes `canExecuteParallelTasksOnDriver`​ from `Backend`​ - this is now an implementation detail in `mapCollectPartitions`​. 4. Broadcasts globals in a separate file for `ServiceBackend`​ 5. Disable all semantic-hash code when it's not featured on ## Security Assessment This change cannot impact the Hail Batch instance as deployed by Broad Institute in GCP
1 parent c9e644b commit e829711

File tree

14 files changed

+613
-569
lines changed

14 files changed

+613
-569
lines changed

hail/hail/src/is/hail/backend/Backend.scala

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package is.hail.backend
22

33
import is.hail.asm4s.HailClassLoader
4+
import is.hail.backend.Backend.PartitionFn
45
import is.hail.backend.spark.SparkBackend
56
import is.hail.expr.ir.{IR, LoweringAnalyses, SortField, TableIR, TableReader}
67
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
@@ -44,43 +45,31 @@ object Backend {
4445
codec.encode(ctx, elementType, t.loadField(off, 0), os)
4546
}
4647

47-
// Currently required by `BackendUtils.collectDArray`
48-
private[this] var instance: Backend = _
49-
50-
def set(b: Backend): Unit =
51-
synchronized { instance = b }
52-
53-
def get: Backend =
54-
synchronized {
55-
assert(instance != null)
56-
instance
57-
}
48+
type PartitionFn = (Array[Byte], Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]
5849
}
5950

6051
abstract class BroadcastValue[T] { def value: T }
6152

62-
trait BackendContext {
63-
def executionCache: ExecutionCache
64-
}
65-
66-
abstract class Backend extends Closeable {
67-
68-
def defaultParallelism: Int
69-
70-
def canExecuteParallelTasksOnDriver: Boolean = true
53+
abstract class DriverRuntimeContext {
7154

72-
def broadcast[T: ClassTag](value: T): BroadcastValue[T]
55+
def executionCache: ExecutionCache
7356

74-
def parallelizeAndComputeWithIndex(
75-
backendContext: BackendContext,
76-
fs: FS,
57+
def mapCollectPartitions(
58+
globals: Array[Byte],
7759
contexts: IndexedSeq[Array[Byte]],
7860
stageIdentifier: String,
7961
dependency: Option[TableStageDependency] = None,
8062
partitions: Option[IndexedSeq[Int]] = None,
8163
)(
82-
f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]
64+
f: PartitionFn
8365
): (Option[Throwable], IndexedSeq[(Array[Byte], Int)])
66+
}
67+
68+
abstract class Backend extends Closeable {
69+
70+
def defaultParallelism: Int
71+
72+
def broadcast[T: ClassTag](value: T): BroadcastValue[T]
8473

8574
def asSpark(implicit E: Enclosing): SparkBackend =
8675
fatal(s"${getClass.getSimpleName}: ${E.value} requires SparkBackend")
@@ -118,5 +107,5 @@ abstract class Backend extends Closeable {
118107

119108
def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
120109

121-
def backendContext(ctx: ExecuteContext): BackendContext
110+
def runtimeContext(ctx: ExecuteContext): DriverRuntimeContext
122111
}

hail/hail/src/is/hail/backend/BackendUtils.scala

Lines changed: 59 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,11 @@ package is.hail.backend
22

33
import is.hail.annotations.Region
44
import is.hail.asm4s._
5-
import is.hail.backend.local.LocalTaskContext
65
import is.hail.expr.ir.analyses.SemanticHash
76
import is.hail.expr.ir.lowering.TableStageDependency
87
import is.hail.io.fs._
9-
import is.hail.services._
108
import is.hail.utils._
119

12-
import scala.util.control.NonFatal
13-
1410
object BackendUtils {
1511
type F = AsmFunction3[Region, Array[Byte], Array[Byte], Array[Byte]]
1612
}
@@ -27,84 +23,83 @@ class BackendUtils(
2723
def getModule(id: String): (HailClassLoader, FS, HailTaskContext, Region) => F = loadedModules(id)
2824

2925
def collectDArray(
30-
backendContext: BackendContext,
31-
theDriverHailClassLoader: HailClassLoader,
32-
fs: FS,
26+
ctx: DriverRuntimeContext,
27+
modID: String,
28+
contexts: Array[Array[Byte]],
29+
globals: Array[Byte],
30+
stageName: String,
31+
tsd: Option[TableStageDependency],
32+
): Array[Array[Byte]] = {
33+
val (failureOpt, results) = runCDA(ctx, globals, contexts, None, modID, stageName, tsd)
34+
failureOpt.foreach(throw _)
35+
Array.tabulate[Array[Byte]](results.length)(results(_)._1)
36+
}
37+
38+
def ccCollectDArray(
39+
ctx: DriverRuntimeContext,
3340
modID: String,
3441
contexts: Array[Array[Byte]],
3542
globals: Array[Byte],
3643
stageName: String,
37-
semhash: Option[SemanticHash.Type],
44+
semhash: SemanticHash.Type,
3845
tsd: Option[TableStageDependency],
3946
): Array[Array[Byte]] = {
4047

41-
val cachedResults =
42-
semhash
43-
.map { s =>
44-
log.info(s"[collectDArray|$stageName]: querying cache for $s")
45-
val cachedResults = backendContext.executionCache.lookup(s)
46-
log.info(s"[collectDArray|$stageName]: found ${cachedResults.length} entries for $s.")
47-
cachedResults
48-
}
49-
.getOrElse(IndexedSeq.empty)
48+
val cachedResults = ctx.executionCache.lookup(semhash)
49+
log.info(s"$stageName: found ${cachedResults.length} entries for $semhash.")
5050

51-
val remainingPartitions =
52-
contexts.indices.filterNot(k => cachedResults.containsOrdered[Int](k, _ < _, _._2))
51+
val todo =
52+
contexts
53+
.indices
54+
.filterNot(k => cachedResults.containsOrdered[Int](k, _ < _, _._2))
5355

54-
val backend = Backend.get
55-
val mod = getModule(modID)
56-
val t = System.nanoTime()
5756
val (failureOpt, successes) =
58-
remainingPartitions match {
57+
todo match {
5958
case Seq() =>
6059
(None, IndexedSeq.empty)
61-
case Seq(k) if backend.canExecuteParallelTasksOnDriver =>
62-
try
63-
using(new LocalTaskContext(k, 0)) { htc =>
64-
using(htc.getRegionPool().getRegion()) { r =>
65-
val f = mod(theDriverHailClassLoader, fs, htc, r)
66-
val res = retryTransientErrors(f(r, contexts(k), globals))
67-
(None, FastSeq(res -> k))
68-
}
69-
}
70-
catch {
71-
case NonFatal(ex) =>
72-
(Some(ex), IndexedSeq.empty)
73-
}
60+
7461
case partitions =>
75-
val globalsBC = backend.broadcast(globals)
76-
val fsConfigBC = backend.broadcast(fs.getConfiguration())
77-
backend.parallelizeAndComputeWithIndex(
78-
backendContext,
79-
fs,
80-
contexts,
81-
stageName,
82-
tsd,
83-
Some(partitions),
84-
) { (ctx, htc, theHailClassLoader, fs) =>
85-
val fsConfig = fsConfigBC.value
86-
val gs = globalsBC.value
87-
fs.setConfiguration(fsConfig)
88-
htc.getRegionPool().scopedRegion { region =>
89-
mod(theHailClassLoader, fs, htc, region)(region, ctx, gs)
90-
}
91-
}
62+
runCDA(ctx, globals, contexts, Some(partitions), modID, stageName, tsd)
9263
}
9364

94-
log.info(
95-
s"[collectDArray|$stageName]: executed ${remainingPartitions.length} tasks in ${formatTime(System.nanoTime() - t)}"
96-
)
65+
val results = merge[(Array[Byte], Int)](cachedResults, successes, _._2 < _._2)
9766

98-
val results =
99-
merge[(Array[Byte], Int)](
100-
cachedResults,
101-
successes.sortBy(_._2),
102-
_._2 < _._2,
103-
)
67+
ctx.executionCache.put(semhash, results)
68+
log.info(s"$stageName: cached ${results.length} entries for $semhash.")
10469

105-
semhash.foreach(s => backendContext.executionCache.put(s, results))
10670
failureOpt.foreach(throw _)
71+
Array.tabulate[Array[Byte]](results.length)(results(_)._1)
72+
}
73+
74+
private[this] def runCDA(
75+
rtx: DriverRuntimeContext,
76+
globals: Array[Byte],
77+
contexts: Array[Array[Byte]],
78+
partitions: Option[IndexedSeq[Int]],
79+
modID: String,
80+
stageName: String,
81+
tsd: Option[TableStageDependency],
82+
): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = {
83+
84+
val mod = getModule(modID)
85+
val start = System.nanoTime()
86+
87+
val r = rtx.mapCollectPartitions(
88+
globals,
89+
contexts,
90+
stageName,
91+
tsd,
92+
partitions,
93+
) { (gs, ctx, htc, theHailClassLoader, fs) =>
94+
htc.getRegionPool().scopedRegion { region =>
95+
mod(theHailClassLoader, fs, htc, region)(region, ctx, gs)
96+
}
97+
}
98+
99+
val elapsed = System.nanoTime() - start
100+
val nTasks = partitions.map(_.length).getOrElse(contexts.length)
101+
log.info(s"$stageName: executed $nTasks tasks in ${formatTime(elapsed)}")
107102

108-
results.map(_._1).toArray
103+
r
109104
}
110105
}

hail/hail/src/is/hail/backend/driver/BatchQueryDriver.scala

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,6 @@ object BatchQueryDriver extends HttpLikeRpc with Logging {
191191
jobConfig,
192192
)
193193

194-
Backend.set(backend)
195-
196194
// FIXME: when can the classloader be shared? (optimizer benefits!)
197195
try runRpc(
198196
Env(
@@ -207,10 +205,7 @@ object BatchQueryDriver extends HttpLikeRpc with Logging {
207205
payload,
208206
)
209207
)
210-
finally {
211-
Backend.set(null)
212-
backend.close()
213-
}
208+
finally backend.close()
214209
}
215210
}
216211

hail/hail/src/is/hail/backend/driver/Py4JQueryDriver.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ final class Py4JQueryDriver(backend: Backend) extends Closeable {
5454
newFs(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
5555
)
5656

57-
Backend.set(backend)
58-
5957
def pyFs: FS =
6058
synchronized(tmpFileManager.fs)
6159

@@ -307,7 +305,6 @@ final class Py4JQueryDriver(backend: Backend) extends Closeable {
307305
compiledCodeCache.clear()
308306
irCache.clear()
309307
coercerCache.clear()
310-
Backend.set(null)
311308
backend.close()
312309
IRFunctionRegistry.clearUserFunctions()
313310
}

hail/hail/src/is/hail/backend/local/LocalBackend.scala

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
package is.hail.backend.local
22

3-
import is.hail.CancellingExecutorService
4-
import is.hail.asm4s._
53
import is.hail.backend._
4+
import is.hail.backend.Backend.PartitionFn
65
import is.hail.expr.Validate
76
import is.hail.expr.ir._
87
import is.hail.expr.ir.analyses.SemanticHash
98
import is.hail.expr.ir.lowering._
10-
import is.hail.io.fs._
119
import is.hail.types._
1210
import is.hail.types.physical.PTuple
1311
import is.hail.utils._
12+
import is.hail.utils.compat.immutable.ArraySeq
1413

1514
import scala.reflect.ClassTag
15+
import scala.util.control.NonFatal
1616

1717
import java.io.PrintWriter
1818

1919
import com.fasterxml.jackson.core.StreamReadConstraints
20-
import com.google.common.util.concurrent.MoreExecutors
2120

2221
class LocalBroadcastValue[T](val value: T) extends BroadcastValue[T] with Serializable
2322

@@ -53,9 +52,6 @@ object LocalBackend extends Backend {
5352
this
5453
}
5554

56-
private case class Context(hcl: HailClassLoader, override val executionCache: ExecutionCache)
57-
extends BackendContext
58-
5955
def broadcast[T: ClassTag](value: T): BroadcastValue[T] = new LocalBroadcastValue[T](value)
6056

6157
private[this] var stageIdx: Int = 0
@@ -67,32 +63,54 @@ object LocalBackend extends Backend {
6763
current
6864
}
6965

70-
override def parallelizeAndComputeWithIndex(
71-
ctx: BackendContext,
72-
fs: FS,
73-
contexts: IndexedSeq[Array[Byte]],
74-
stageIdentifier: String,
75-
dependency: Option[TableStageDependency],
76-
partitions: Option[IndexedSeq[Int]],
77-
)(
78-
f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]
79-
): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = {
80-
81-
val stageId = nextStageId()
82-
val hcl = ctx.asInstanceOf[Context].hcl
83-
runAllKeepFirstError(new CancellingExecutorService(MoreExecutors.newDirectExecutorService())) {
84-
partitions.getOrElse(contexts.indices).map { i =>
85-
(
86-
() => using(new LocalTaskContext(i, stageId))(f(contexts(i), _, hcl, fs)),
87-
i,
88-
)
66+
override def runtimeContext(ctx: ExecuteContext): DriverRuntimeContext = {
67+
new DriverRuntimeContext {
68+
69+
override val executionCache: ExecutionCache =
70+
ExecutionCache.fromFlags(ctx.flags, ctx.fs, ctx.localTmpdir)
71+
72+
override def mapCollectPartitions(
73+
globals: Array[Byte],
74+
contexts: IndexedSeq[Array[Byte]],
75+
stageIdentifier: String,
76+
dependency: Option[TableStageDependency],
77+
partitions: Option[IndexedSeq[Int]],
78+
)(
79+
f: PartitionFn
80+
): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = {
81+
82+
val todo: IndexedSeq[Int] =
83+
partitions.getOrElse(contexts.indices)
84+
85+
val results = ArraySeq.newBuilder[(Array[Byte], Int)]
86+
results.sizeHint(todo.length)
87+
88+
var failure: Option[Throwable] =
89+
None
90+
91+
val stageId = nextStageId()
92+
93+
try
94+
for (idx <- todo) {
95+
using(new LocalTaskContext(idx, stageId)) { tx =>
96+
results += {
97+
(
98+
f(globals, contexts(idx), tx, ctx.theHailClassLoader, ctx.fs),
99+
idx,
100+
)
101+
}
102+
}
103+
}
104+
catch {
105+
case NonFatal(t) =>
106+
failure = Some(t)
107+
}
108+
109+
(failure, results.result())
89110
}
90111
}
91112
}
92113

93-
override def backendContext(ctx: ExecuteContext): BackendContext =
94-
Context(ctx.theHailClassLoader, ExecutionCache.fromFlags(ctx.flags, ctx.fs, ctx.tmpdir))
95-
96114
def defaultParallelism: Int = 1
97115

98116
def close(): Unit =
@@ -116,7 +134,9 @@ object LocalBackend extends Backend {
116134
Validate(ir)
117135
val queryID = Backend.nextID()
118136
log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}")
119-
ctx.irMetadata.semhash = SemanticHash(ctx, ir)
137+
if (ctx.flags.isDefined(ExecutionCache.Flags.UseFastRestarts))
138+
ctx.irMetadata.semhash = SemanticHash(ctx, ir)
139+
120140
val res = _jvmLowerAndExecute(ctx, ir)
121141
log.info(s"finished execution of query $queryID")
122142
res

0 commit comments

Comments
 (0)