Skip to content

Commit fc30260

Browse files
committed
[query] remove global Backend field
1 parent 56ea4b5 commit fc30260

File tree

12 files changed

+480
-459
lines changed

12 files changed

+480
-459
lines changed

hail/hail/src/is/hail/backend/Backend.scala

Lines changed: 15 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
package is.hail.backend
22

33
import is.hail.asm4s.HailClassLoader
4+
import is.hail.backend.Backend.PartitionFn
45
import is.hail.backend.spark.SparkBackend
56
import is.hail.expr.ir.{IR, LoweringAnalyses, SortField, TableIR, TableReader}
67
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
@@ -44,43 +45,31 @@ object Backend {
4445
codec.encode(ctx, elementType, t.loadField(off, 0), os)
4546
}
4647

47-
// Currently required by `BackendUtils.collectDArray`
48-
private[this] var instance: Backend = _
49-
50-
def set(b: Backend): Unit =
51-
synchronized { instance = b }
52-
53-
def get: Backend =
54-
synchronized {
55-
assert(instance != null)
56-
instance
57-
}
48+
type PartitionFn = (Array[Byte], Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]
5849
}
5950

6051
abstract class BroadcastValue[T] { def value: T }
6152

62-
trait BackendContext {
63-
def executionCache: ExecutionCache
64-
}
65-
66-
abstract class Backend extends Closeable {
67-
68-
def defaultParallelism: Int
69-
70-
def canExecuteParallelTasksOnDriver: Boolean = true
53+
abstract class DriverContext {
7154

72-
def broadcast[T: ClassTag](value: T): BroadcastValue[T]
55+
def executionCache: ExecutionCache
7356

74-
def parallelizeAndComputeWithIndex(
75-
backendContext: BackendContext,
76-
fs: FS,
57+
def collectDistributedArray(
58+
globals: Array[Byte],
7759
contexts: IndexedSeq[Array[Byte]],
7860
stageIdentifier: String,
7961
dependency: Option[TableStageDependency] = None,
8062
partitions: Option[IndexedSeq[Int]] = None,
8163
)(
82-
f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]
64+
f: PartitionFn
8365
): (Option[Throwable], IndexedSeq[(Array[Byte], Int)])
66+
}
67+
68+
abstract class Backend extends Closeable {
69+
70+
def defaultParallelism: Int
71+
72+
def broadcast[T: ClassTag](value: T): BroadcastValue[T]
8473

8574
def asSpark(implicit E: Enclosing): SparkBackend =
8675
fatal(s"${getClass.getSimpleName}: ${E.value} requires SparkBackend")
@@ -118,5 +107,5 @@ abstract class Backend extends Closeable {
118107

119108
def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
120109

121-
def backendContext(ctx: ExecuteContext): BackendContext
110+
def driverContext(ctx: ExecuteContext): DriverContext
122111
}

hail/hail/src/is/hail/backend/BackendUtils.scala

Lines changed: 7 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,7 @@ class BackendUtils(
2727
def getModule(id: String): (HailClassLoader, FS, HailTaskContext, Region) => F = loadedModules(id)
2828

2929
def collectDArray(
30-
backendContext: BackendContext,
31-
theDriverHailClassLoader: HailClassLoader,
32-
fs: FS,
30+
ctx: DriverContext,
3331
modID: String,
3432
contexts: Array[Array[Byte]],
3533
globals: Array[Byte],
@@ -42,7 +40,7 @@ class BackendUtils(
4240
semhash
4341
.map { s =>
4442
log.info(s"[collectDArray|$stageName]: querying cache for $s")
45-
val cachedResults = backendContext.executionCache.lookup(s)
43+
val cachedResults = ctx.executionCache.lookup(s)
4644
log.info(s"[collectDArray|$stageName]: found ${cachedResults.length} entries for $s.")
4745
cachedResults
4846
}
@@ -51,40 +49,21 @@ class BackendUtils(
5149
val remainingPartitions =
5250
contexts.indices.filterNot(k => cachedResults.containsOrdered[Int](k, _ < _, _._2))
5351

54-
val backend = Backend.get
5552
val mod = getModule(modID)
5653
val t = System.nanoTime()
5754
val (failureOpt, successes) =
5855
remainingPartitions match {
5956
case Seq() =>
6057
(None, IndexedSeq.empty)
61-
case Seq(k) if backend.canExecuteParallelTasksOnDriver =>
62-
try
63-
using(new LocalTaskContext(k, 0)) { htc =>
64-
using(htc.getRegionPool().getRegion()) { r =>
65-
val f = mod(theDriverHailClassLoader, fs, htc, r)
66-
val res = retryTransientErrors(f(r, contexts(k), globals))
67-
(None, FastSeq(res -> k))
68-
}
69-
}
70-
catch {
71-
case NonFatal(ex) =>
72-
(Some(ex), IndexedSeq.empty)
73-
}
58+
7459
case partitions =>
75-
val globalsBC = backend.broadcast(globals)
76-
val fsConfigBC = backend.broadcast(fs.getConfiguration())
77-
backend.parallelizeAndComputeWithIndex(
78-
backendContext,
79-
fs,
60+
ctx.collectDistributedArray(
61+
globals,
8062
contexts,
8163
stageName,
8264
tsd,
8365
Some(partitions),
84-
) { (ctx, htc, theHailClassLoader, fs) =>
85-
val fsConfig = fsConfigBC.value
86-
val gs = globalsBC.value
87-
fs.setConfiguration(fsConfig)
66+
) { (gs, ctx, htc, theHailClassLoader, fs) =>
8867
htc.getRegionPool().scopedRegion { region =>
8968
mod(theHailClassLoader, fs, htc, region)(region, ctx, gs)
9069
}
@@ -102,7 +81,7 @@ class BackendUtils(
10281
_._2 < _._2,
10382
)
10483

105-
semhash.foreach(s => backendContext.executionCache.put(s, results))
84+
semhash.foreach(s => ctx.executionCache.put(s, results))
10685
failureOpt.foreach(throw _)
10786

10887
results.map(_._1).toArray

hail/hail/src/is/hail/backend/ExecuteContext.scala

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -148,11 +148,7 @@ class ExecuteContext(
148148
f: (HailClassLoader, FS, HailTaskContext, Region) => T
149149
)(implicit E: Enclosing
150150
): T =
151-
using(new LocalTaskContext(0, 0)) { tc =>
152-
Backend.set(backend)
153-
try time(f(theHailClassLoader, fs, tc, r))
154-
finally Backend.set(null)
155-
}
151+
using(new LocalTaskContext(0, 0))(tc => time(f(theHailClassLoader, fs, tc, r)))
156152

157153
def createTmpPath(prefix: String, extension: String = null, local: Boolean = false): String =
158154
tempFileManager.newTmpPath(if (local) localTmpdir else tmpdir, prefix, extension)

hail/hail/src/is/hail/backend/local/LocalBackend.scala

Lines changed: 47 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1,23 +1,22 @@
11
package is.hail.backend.local
22

3-
import is.hail.CancellingExecutorService
4-
import is.hail.asm4s._
53
import is.hail.backend._
4+
import is.hail.backend.Backend.PartitionFn
65
import is.hail.expr.Validate
76
import is.hail.expr.ir._
87
import is.hail.expr.ir.analyses.SemanticHash
98
import is.hail.expr.ir.lowering._
10-
import is.hail.io.fs._
119
import is.hail.types._
1210
import is.hail.types.physical.PTuple
1311
import is.hail.utils._
1412

13+
import scala.collection.mutable
1514
import scala.reflect.ClassTag
15+
import scala.util.control.NonFatal
1616

1717
import java.io.PrintWriter
1818

1919
import com.fasterxml.jackson.core.StreamReadConstraints
20-
import com.google.common.util.concurrent.MoreExecutors
2120

2221
class LocalBroadcastValue[T](val value: T) extends BroadcastValue[T] with Serializable
2322

@@ -53,9 +52,6 @@ object LocalBackend extends Backend {
5352
this
5453
}
5554

56-
private case class Context(hcl: HailClassLoader, override val executionCache: ExecutionCache)
57-
extends BackendContext
58-
5955
def broadcast[T: ClassTag](value: T): BroadcastValue[T] = new LocalBroadcastValue[T](value)
6056

6157
private[this] var stageIdx: Int = 0
@@ -67,32 +63,54 @@ object LocalBackend extends Backend {
6763
current
6864
}
6965

70-
override def parallelizeAndComputeWithIndex(
71-
ctx: BackendContext,
72-
fs: FS,
73-
contexts: IndexedSeq[Array[Byte]],
74-
stageIdentifier: String,
75-
dependency: Option[TableStageDependency],
76-
partitions: Option[IndexedSeq[Int]],
77-
)(
78-
f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]
79-
): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = {
80-
81-
val stageId = nextStageId()
82-
val hcl = ctx.asInstanceOf[Context].hcl
83-
runAllKeepFirstError(new CancellingExecutorService(MoreExecutors.newDirectExecutorService())) {
84-
partitions.getOrElse(contexts.indices).map { i =>
85-
(
86-
() => using(new LocalTaskContext(i, stageId))(f(contexts(i), _, hcl, fs)),
87-
i,
88-
)
66+
override def driverContext(ctx: ExecuteContext): DriverContext = {
67+
new DriverContext {
68+
69+
override val executionCache: ExecutionCache =
70+
ExecutionCache.fromFlags(ctx.flags, ctx.fs, ctx.localTmpdir)
71+
72+
override def collectDistributedArray(
73+
globals: Array[Byte],
74+
contexts: IndexedSeq[Array[Byte]],
75+
stageIdentifier: String,
76+
dependency: Option[TableStageDependency],
77+
partitions: Option[IndexedSeq[Int]],
78+
)(
79+
f: PartitionFn
80+
): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = {
81+
82+
val todo: IndexedSeq[Int] =
83+
partitions.getOrElse(contexts.indices)
84+
85+
val results: mutable.ArrayBuffer[(Array[Byte], Int)] =
86+
new mutable.ArrayBuffer(todo.length)
87+
88+
var failure: Option[Throwable] =
89+
None
90+
91+
val stageId = nextStageId()
92+
93+
try
94+
for (idx <- todo) {
95+
using(new LocalTaskContext(idx, stageId)) { tx =>
96+
results += {
97+
(
98+
f(globals, contexts(idx), tx, ctx.theHailClassLoader, ctx.fs),
99+
idx,
100+
)
101+
}
102+
}
103+
}
104+
catch {
105+
case NonFatal(t) =>
106+
failure = Some(t)
107+
}
108+
109+
(failure, results)
89110
}
90111
}
91112
}
92113

93-
override def backendContext(ctx: ExecuteContext): BackendContext =
94-
Context(ctx.theHailClassLoader, ExecutionCache.fromFlags(ctx.flags, ctx.fs, ctx.tmpdir))
95-
96114
def defaultParallelism: Int = 1
97115

98116
def close(): Unit =

0 commit comments

Comments
 (0)