Skip to content

Commit 2e67e7b

Browse files
committed
[query] remove branchingFactor from HailContext`
1 parent 969d3a0 commit 2e67e7b

File tree

10 files changed

+76
-29
lines changed

10 files changed

+76
-29
lines changed

hail/hail/src/is/hail/HailContext.scala

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -92,19 +92,13 @@ object HailContext {
9292
}
9393
}
9494

95-
def getOrCreate(backend: Backend, branchingFactor: Int = 50): HailContext = {
96-
if (theContext == null)
97-
return HailContext(backend, branchingFactor)
98-
99-
if (theContext.branchingFactor != branchingFactor)
100-
warn(
101-
s"Requested branchingFactor $branchingFactor, but already initialized to ${theContext.branchingFactor}. Ignoring requested setting."
102-
)
103-
104-
theContext
105-
}
95+
def getOrCreate(backend: Backend): HailContext =
96+
synchronized {
97+
if (isInitialized) theContext
98+
else HailContext(backend)
99+
}
106100

107-
def apply(backend: Backend, branchingFactor: Int = 50): HailContext = synchronized {
101+
def apply(backend: Backend): HailContext = synchronized {
108102
require(theContext == null)
109103
checkJavaVersion()
110104

@@ -122,7 +116,7 @@ object HailContext {
122116
)
123117
}
124118

125-
theContext = new HailContext(backend, branchingFactor)
119+
theContext = new HailContext(backend)
126120

127121
info(s"Running Hail version ${theContext.version}")
128122

@@ -164,8 +158,7 @@ object HailContext {
164158
}
165159

166160
class HailContext private (
167-
var backend: Backend,
168-
val branchingFactor: Int,
161+
var backend: Backend
169162
) {
170163
def stop(): Unit = HailContext.stop()
171164

hail/hail/src/is/hail/HailFeatureFlags.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@ package is.hail
22

33
import is.hail.backend.ExecutionCache
44
import is.hail.backend.spark.SparkBackend
5-
import is.hail.expr.ir.Optimize
5+
import is.hail.expr.ir.{Optimize, agg}
66
import is.hail.io.fs.RequesterPaysConfig
77
import is.hail.types.encoded.EType
88
import is.hail.utils._
99

1010
import scala.collection.mutable
11-
1211
import org.json4s.JsonAST.{JArray, JObject, JString}
1312

1413
object HailFeatureFlags {
@@ -48,6 +47,7 @@ object HailFeatureFlags {
4847
),
4948
(Optimize.Flags.Optimize, "HAIL_QUERY_OPTIMIZE" -> "1"),
5049
(Optimize.Flags.MaxOptimizerIterations, "HAIL_OPTIMIZER_ITERATIONS" -> null),
50+
(agg.Flags.BranchFactor, "HAIL_BRANCH_FACTOR" -> null),
5151
)
5252

5353
def fromEnv(m: Map[String, String] = sys.env): HailFeatureFlags =
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package is.hail.expr.ir
2+
3+
import is.hail.backend.ExecuteContext
4+
import is.hail.utils.HailException
5+
6+
package agg {
7+
object Flags {
8+
val BranchFactor = "branch_factor"
9+
}
10+
}
11+
12+
package object agg {
13+
val DefaultBranchFactor: Int = 50
14+
15+
def branchFactor(ctx: ExecuteContext): Int = {
16+
val factor =
17+
ctx.flags.lookup(Flags.BranchFactor).map { s =>
18+
val factor =
19+
try s.toInt
20+
catch {
21+
case _: NumberFormatException =>
22+
throw new HailException(
23+
f"'${Flags.BranchFactor}' must be a positive integer, got '$s'."
24+
)
25+
}
26+
27+
if (factor < 0)
28+
throw new HailException(
29+
f"'${Flags.BranchFactor}' must be greater than 0, got '$factor'."
30+
)
31+
32+
factor
33+
}
34+
35+
factor.getOrElse(DefaultBranchFactor)
36+
}
37+
}

hail/hail/src/is/hail/expr/ir/lowering/LowerTableIR.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package is.hail.expr.ir.lowering
22

3-
import is.hail.HailContext
43
import is.hail.backend.ExecuteContext
54
import is.hail.expr.ir.{agg, TableNativeWriter, _}
65
import is.hail.expr.ir.defs._
@@ -850,7 +849,7 @@ object LowerTableIR {
850849
InitFromSerializedValue(i, GetTupleElement(initStateRef, i), agg.state)
851850
})
852851

853-
val branchFactor = HailContext.get.branchingFactor
852+
val branchFactor = agg.branchFactor(ctx)
854853
val useTreeAggregate = aggs.shouldTreeAggregate && branchFactor < lc.numPartitions
855854
val isCommutative = aggs.isCommutative
856855
log.info(s"Aggregate: useTreeAggregate=$useTreeAggregate")
@@ -1681,7 +1680,7 @@ object LowerTableIR {
16811680
val initFromSerializedStates = Begin(aggs.aggs.zipWithIndex.map { case (agg, i) =>
16821681
InitFromSerializedValue(i, GetTupleElement(initStateRef, i), agg.state)
16831682
})
1684-
val branchFactor = HailContext.get.branchingFactor
1683+
val branchFactor = agg.branchFactor(ctx)
16851684
val big = aggs.shouldTreeAggregate && branchFactor < lc.numPartitions
16861685
val (partitionPrefixSumValues, transformPrefixSum): (IR, IR => IR) = if (big) {
16871686
val tmpDir = ctx.createTmpPath("aggregate_intermediates/")

hail/hail/src/is/hail/rvd/RVD.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import is.hail.annotations._
55
import is.hail.asm4s.{theHailClassLoaderForSparkWorkers, HailClassLoader}
66
import is.hail.backend.{ExecuteContext, HailStateManager, HailTaskContext}
77
import is.hail.backend.spark.{SparkBackend, SparkTaskContext}
8-
import is.hail.expr.ir.InferPType
8+
import is.hail.expr.ir.{agg, InferPType}
99
import is.hail.expr.ir.PruneDeadFields.isSupertype
1010
import is.hail.io._
1111
import is.hail.io.index.IndexWriter
@@ -722,7 +722,7 @@ class RVD(
722722
}
723723

724724
if (tree) {
725-
val depth = treeAggDepth(getNumPartitions, HailContext.get.branchingFactor)
725+
val depth = treeAggDepth(getNumPartitions, agg.branchFactor(execCtx))
726726
val scale = math.max(
727727
math.ceil(math.pow(getNumPartitions, 1.0 / depth)).toInt,
728728
2,

hail/python/hail/backend/backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ class Backend(abc.ABC):
197197
"write_ir_files": ("HAIL_WRITE_IR_FILES", None),
198198
"optimize": ("HAIL_QUERY_OPTIMIZE", "1"),
199199
"max_optimizer_iterations": ("HAIL_OPTIMIZER_ITERATIONS", None),
200+
"branch_factor": ("HAIL_BRANCH_FACTOR", None),
200201
}
201202

202203
def _valid_flags(self) -> AbstractSet[str]:

hail/python/hail/backend/local_backend.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import sys
55
from contextlib import ExitStack
6-
from typing import List, Optional, Tuple, Union
6+
from typing import Dict, List, Optional, Tuple, Union
77

88
from py4j.java_gateway import GatewayParameters, JavaGateway, launch_gateway
99

@@ -78,7 +78,7 @@ def __init__(
7878
append,
7979
skip_logging_configuration,
8080
)
81-
jhc = hail_package.HailContext.apply(jbackend, branching_factor)
81+
jhc = hail_package.HailContext.apply(jbackend)
8282

8383
super().__init__(self._gateway.jvm, jbackend, jhc, tmpdir, tmpdir)
8484
self.gcs_requester_pays_configuration = gcs_requester_pays_configuration
@@ -87,7 +87,12 @@ def __init__(
8787
)
8888

8989
self._logger = None
90-
self._initialize_flags({})
90+
91+
flags: Dict[str, str] = {}
92+
if branching_factor is not None:
93+
flags['branch_factor'] = str(branching_factor)
94+
95+
self._initialize_flags(flags)
9196

9297
def validate_file(self, uri: str) -> None:
9398
async_to_blocking(validate_file(uri, self._fs.afs))

hail/python/hail/backend/service_backend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ async def create(
130130
regions: Optional[List[str]] = None,
131131
gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None,
132132
gcs_bucket_allow_list: Optional[List[str]] = None,
133+
branching_factor: Optional[int] = None,
133134
):
134135
async_exit_stack = AsyncExitStack()
135136
billing_project = configuration_of(ConfigVariable.BATCH_BILLING_PROJECT, billing_project, None)
@@ -183,6 +184,9 @@ async def create(
183184
disable_progress_bar = len(disable_progress_bar_str) > 0
184185

185186
flags = flags or {}
187+
if branching_factor is not None:
188+
flags['branch_factor'] = str(branching_factor)
189+
186190
if 'gcs_requester_pays_project' in flags or 'gcs_requester_pays_buckets' in flags:
187191
raise ValueError(
188192
'Specify neither gcs_requester_pays_project nor gcs_requester_'

hail/python/hail/backend/spark_backend.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import sys
3-
from typing import Any, Optional
3+
from typing import Any, Dict, Optional
44

55
import orjson
66
import pyspark
@@ -117,7 +117,7 @@ def __init__(
117117
skip_logging_configuration,
118118
min_block_size,
119119
)
120-
jhc = hail_package.HailContext.getOrCreate(jbackend, branching_factor)
120+
jhc = hail_package.HailContext.getOrCreate(jbackend)
121121
else:
122122
jbackend = hail_package.backend.spark.SparkBackend.apply(
123123
jsc,
@@ -130,7 +130,7 @@ def __init__(
130130
skip_logging_configuration,
131131
min_block_size,
132132
)
133-
jhc = hail_package.HailContext.apply(jbackend, branching_factor)
133+
jhc = hail_package.HailContext.apply(jbackend)
134134

135135
self._jsc = jbackend.sc()
136136
if sc:
@@ -152,7 +152,11 @@ def __init__(
152152

153153
jbackend.pyStartProgressBar()
154154

155-
self._initialize_flags({})
155+
flags: Dict[str, str] = {}
156+
if branching_factor is not None:
157+
flags['branch_factor'] = str(branching_factor)
158+
159+
self._initialize_flags(flags)
156160

157161
self._router_async_fs = RouterAsyncFS(
158162
gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_config}

hail/python/hail/context.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ def init(
386386
gcs_requester_pays_configuration=gcs_requester_pays_configuration,
387387
regions=regions,
388388
gcs_bucket_allow_list=gcs_bucket_allow_list,
389+
branching_factor=branching_factor,
389390
)
390391
)
391392
if backend == 'spark':
@@ -522,6 +523,7 @@ def init_spark(
522523
gcs_requester_pays_configuration=nullable(oneof(str, sized_tupleof(str, sequenceof(str)))),
523524
regions=nullable(sequenceof(str)),
524525
gcs_bucket_allow_list=nullable(sequenceof(str)),
526+
braching_factor=nullable(int),
525527
)
526528
async def init_batch(
527529
*,
@@ -545,6 +547,7 @@ async def init_batch(
545547
gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None,
546548
regions: Optional[List[str]] = None,
547549
gcs_bucket_allow_list: Optional[List[str]] = None,
550+
branching_factor: Optional[int] = None,
548551
):
549552
from hail.backend.service_backend import ServiceBackend
550553

@@ -563,6 +566,7 @@ async def init_batch(
563566
regions=regions,
564567
gcs_requester_pays_configuration=gcs_requester_pays_configuration,
565568
gcs_bucket_allow_list=gcs_bucket_allow_list,
569+
branching_factor=branching_factor,
566570
)
567571

568572
log = _get_log(log)

0 commit comments

Comments
 (0)