Skip to content

Commit 2957928

Browse files
committed
wire up spark, local and py4j backends
1 parent 1bf70ac commit 2957928

File tree

8 files changed

+59
-56
lines changed

8 files changed

+59
-56
lines changed

hail/python/hail/backend/local_backend.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,23 +81,14 @@ def __init__(
8181
)
8282
jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations)
8383

84-
super(LocalBackend, self).__init__(self._gateway.jvm, jbackend, jhc)
84+
super().__init__(self._gateway.jvm, jbackend, jhc, tmpdir, tmpdir)
85+
self.gcs_requester_pays_configuration = gcs_requester_pays_configuration
8586
self._fs = self._exit_stack.enter_context(
8687
RouterFS(gcs_kwargs={'gcs_requester_pays_configuration': gcs_requester_pays_configuration})
8788
)
8889

8990
self._logger = None
90-
91-
flags = {}
92-
if gcs_requester_pays_configuration is not None:
93-
if isinstance(gcs_requester_pays_configuration, str):
94-
flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration
95-
else:
96-
assert isinstance(gcs_requester_pays_configuration, tuple)
97-
flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration[0]
98-
flags['gcs_requester_pays_buckets'] = ','.join(gcs_requester_pays_configuration[1])
99-
100-
self._initialize_flags(flags)
91+
self._initialize_flags({})
10192

10293
def validate_file(self, uri: str) -> None:
10394
async_to_blocking(validate_file(uri, self._fs.afs))

hail/python/hail/backend/py4j_backend.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313

1414
import hail
1515
from hail.expr import construct_expr
16+
from hail.fs.hadoop_fs import HadoopFS
1617
from hail.ir import JavaIR
1718
from hail.utils.java import Env, FatalError, scala_package_object
19+
from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration
1820

1921
from ..hail_logging import Logger
2022
from .backend import ActionTag, Backend, fatal_error_from_java_error_triplet
@@ -193,11 +195,17 @@ def decode_bytearray(encoded):
193195
self._utils_package_object = scala_package_object(self._hail_package.utils)
194196
self._jhc = jhc
195197

196-
self._jbackend = self._hail_package.backend.api.P4jBackendApi(jbackend)
198+
self._jbackend = self._hail_package.backend.api.Py4JBackendApi(jbackend)
199+
self._jbackend.pySetLocalTmp(tmpdir)
200+
self._jbackend.pySetRemoteTmp(remote_tmpdir)
201+
197202
self._jhttp_server = self._jbackend.pyHttpServer()
198-
self._backend_server_port: int = self._jbackend.HttpServer.port()
203+
self._backend_server_port: int = self._jhttp_server.port()
199204
self._requests_session = requests.Session()
200205

206+
self._gcs_requester_pays_config = None
207+
self._fs = None
208+
201209
# This has to go after creating the SparkSession. Unclear why.
202210
# Maybe it does its own patch?
203211
install_exception_handler()
@@ -221,6 +229,23 @@ def hail_package(self):
221229
def utils_package_object(self):
222230
return self._utils_package_object
223231

232+
@property
233+
def gcs_requester_pays_configuration(self) -> Optional[GCSRequesterPaysConfiguration]:
234+
return self._gcs_requester_pays_config
235+
236+
@gcs_requester_pays_configuration.setter
237+
def gcs_requester_pays_configuration(self, config: Optional[GCSRequesterPaysConfiguration]):
238+
self._gcs_requester_pays_config = config
239+
project, buckets = (None, None) if config is None else (config, None) if isinstance(config, str) else config
240+
self._jbackend.pySetGcsRequesterPaysConfig(project, buckets)
241+
self._fs = None # stale
242+
243+
@property
244+
def fs(self):
245+
if self._fs is None:
246+
self._fs = HadoopFS(self._utils_package_object, self._jbackend.pyFs())
247+
return self._fs
248+
224249
@property
225250
def logger(self):
226251
if self._logger is None:

hail/python/hail/backend/spark_backend.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
import pyspark.sql
88

99
from hail.expr.table_type import ttable
10-
from hail.fs.hadoop_fs import HadoopFS
1110
from hail.ir import BaseIR
1211
from hail.ir.renderer import CSERenderer
1312
from hail.table import Table
1413
from hail.utils import copy_log
14+
from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration
1515
from hailtop.aiotools.router_fs import RouterAsyncFS
1616
from hailtop.aiotools.validators import validate_file
1717
from hailtop.utils import async_to_blocking
@@ -47,12 +47,9 @@ def __init__(
4747
skip_logging_configuration,
4848
optimizer_iterations,
4949
*,
50-
gcs_requester_pays_project: Optional[str] = None,
51-
gcs_requester_pays_buckets: Optional[str] = None,
50+
gcs_requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None,
5251
copy_log_on_error: bool = False,
5352
):
54-
assert gcs_requester_pays_project is not None or gcs_requester_pays_buckets is None
55-
5653
try:
5754
local_jar_info = local_jar_information()
5855
except ValueError:
@@ -120,10 +117,6 @@ def __init__(
120117
append,
121118
skip_logging_configuration,
122119
min_block_size,
123-
tmpdir,
124-
local_tmpdir,
125-
gcs_requester_pays_project,
126-
gcs_requester_pays_buckets,
127120
)
128121
jhc = hail_package.HailContext.getOrCreate(jbackend, branching_factor, optimizer_iterations)
129122
else:
@@ -137,10 +130,6 @@ def __init__(
137130
append,
138131
skip_logging_configuration,
139132
min_block_size,
140-
tmpdir,
141-
local_tmpdir,
142-
gcs_requester_pays_project,
143-
gcs_requester_pays_buckets,
144133
)
145134
jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations)
146135

@@ -149,12 +138,12 @@ def __init__(
149138
self.sc = sc
150139
else:
151140
self.sc = pyspark.SparkContext(gateway=self._gateway, jsc=jvm.JavaSparkContext(self._jsc))
152-
self._jspark_session = jbackend.sparkSession()
141+
self._jspark_session = jbackend.sparkSession().apply()
153142
self._spark_session = pyspark.sql.SparkSession(self.sc, self._jspark_session)
154143

155-
super(SparkBackend, self).__init__(jvm, jbackend, jhc)
144+
super().__init__(jvm, jbackend, jhc, local_tmpdir, tmpdir)
145+
self.gcs_requester_pays_configuration = gcs_requester_pays_config
156146

157-
self._fs = None
158147
self._logger = None
159148

160149
if not quiet:
@@ -167,7 +156,7 @@ def __init__(
167156
self._initialize_flags({})
168157

169158
self._router_async_fs = RouterAsyncFS(
170-
gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_project}
159+
gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_config}
171160
)
172161

173162
self._tmpdir = tmpdir
@@ -181,12 +170,6 @@ def stop(self):
181170
self.sc.stop()
182171
self.sc = None
183172

184-
@property
185-
def fs(self):
186-
if self._fs is None:
187-
self._fs = HadoopFS(self._utils_package_object, self._jbackend.fs())
188-
return self._fs
189-
190173
def from_spark(self, df, key):
191174
result_tuple = self._jbackend.pyFromDF(df._jdf, key)
192175
tir_id, type_json = result_tuple._1(), result_tuple._2()

hail/python/hail/context.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -474,14 +474,10 @@ def init_spark(
474474
optimizer_iterations = get_env_or_default(_optimizer_iterations, 'HAIL_OPTIMIZER_ITERATIONS', 3)
475475

476476
app_name = app_name or 'Hail'
477-
(
478-
gcs_requester_pays_project,
479-
gcs_requester_pays_buckets,
480-
) = convert_gcs_requester_pays_configuration_to_hadoop_conf_style(
481-
get_gcs_requester_pays_configuration(
482-
gcs_requester_pays_configuration=gcs_requester_pays_configuration,
483-
)
477+
gcs_requester_pays_configuration = get_gcs_requester_pays_configuration(
478+
gcs_requester_pays_configuration=gcs_requester_pays_configuration,
484479
)
480+
485481
backend = SparkBackend(
486482
idempotent,
487483
sc,
@@ -498,8 +494,7 @@ def init_spark(
498494
local_tmpdir,
499495
skip_logging_configuration,
500496
optimizer_iterations,
501-
gcs_requester_pays_project=gcs_requester_pays_project,
502-
gcs_requester_pays_buckets=gcs_requester_pays_buckets,
497+
gcs_requester_pays_config=gcs_requester_pays_configuration,
503498
copy_log_on_error=copy_log_on_error,
504499
)
505500
if not backend.fs.exists(tmpdir):

hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ import is.hail.backend._
66
import is.hail.backend.caching.BlockMatrixCache
77
import is.hail.backend.spark.SparkBackend
88
import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex}
9-
import is.hail.expr.ir.{BaseIR, BlockMatrixIR, CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue}
9+
import is.hail.expr.ir.{
10+
BaseIR, BlockMatrixIR, CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser,
11+
Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue,
12+
}
1013
import is.hail.expr.ir.IRParser.parseType
1114
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
1215
import is.hail.expr.ir.functions.IRFunctionRegistry
@@ -21,14 +24,18 @@ import is.hail.utils._
2124
import is.hail.utils.ExecutionTimer.Timings
2225
import is.hail.variant.ReferenceGenome
2326

27+
import scala.annotation.nowarn
2428
import scala.collection.mutable
2529
import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter}
30+
2631
import java.io.Closeable
2732
import java.net.InetSocketAddress
2833
import java.util
2934
import java.util.concurrent._
35+
3036
import com.google.api.client.http.HttpStatusCodes
3137
import com.sun.net.httpserver.{HttpExchange, HttpServer}
38+
import javax.annotation.Nullable
3239
import org.apache.hadoop
3340
import org.apache.hadoop.conf.Configuration
3441
import org.apache.spark.sql.DataFrame
@@ -37,9 +44,6 @@ import org.json4s._
3744
import org.json4s.jackson.{JsonMethods, Serialization}
3845
import sourcecode.Enclosing
3946

40-
import javax.annotation.Nullable
41-
import scala.annotation.nowarn
42-
4347
final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling {
4448

4549
private[this] val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
@@ -74,6 +78,9 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
7478
manager.close()
7579
}
7680

81+
def pyFs: FS =
82+
tmpFileManager.getFs
83+
7784
def pyGetFlag(name: String): String =
7885
flags.get(name)
7986

@@ -83,13 +90,14 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
8390
def pyAvailableFlags: java.util.ArrayList[String] =
8491
flags.available
8592

86-
def pySetTmpdir(tmp: String): Unit =
93+
def pySetRemoteTmp(tmp: String): Unit =
8794
tmpdir = tmp
8895

8996
def pySetLocalTmp(tmp: String): Unit =
9097
localTmpdir = tmp
9198

92-
def pySetRequesterPays(@Nullable project: String, @Nullable buckets: util.List[String]): Unit = {
99+
def pySetGcsRequesterPaysConfig(@Nullable project: String, @Nullable buckets: util.List[String])
100+
: Unit = {
93101
val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags)
94102

95103
val rpConfig: Option[RequesterPaysConfig] =

hail/src/main/scala/is/hail/backend/api/ServiceBackendApi.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ import is.hail.io.fs.{CloudStorageFSConfig, FS, RouterFS}
1111
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
1212
import is.hail.services._
1313
import is.hail.types.virtual.Kinds
14-
import is.hail.utils.{toRichIterable, using, ErrorHandling, ExecutionTimer, HailWorkerException, Logging}
14+
import is.hail.utils.{
15+
toRichIterable, using, ErrorHandling, ExecutionTimer, HailWorkerException, Logging,
16+
}
1517
import is.hail.utils.ExecutionTimer.Timings
1618
import is.hail.variant.ReferenceGenome
1719

hail/src/main/scala/is/hail/utils/package.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ package utils {
9292
}
9393
}
9494

95-
9695
class Lazy[A] private[utils] (f: => A) {
9796
private[this] var option: Option[A] = None
9897

hail/src/test/scala/is/hail/HailSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ class HailSuite extends TestNGSuite with TestUtils {
7070
var pool: RegionPool = _
7171
private[this] var ctx_ : ExecuteContext = _
7272

73-
def backend: Backend = ctx.backend
74-
def sc: SparkContext = backend.asSpark.sc
73+
def backend: Backend = hc.backend
74+
def sc: SparkContext = hc.backend.asSpark.sc
7575
def timer: ExecutionTimer = ctx.timer
7676
def theHailClassLoader: HailClassLoader = ctx.theHailClassLoader
7777
override def ctx: ExecuteContext = ctx_

Comments: 0 commit comments