@@ -7,11 +7,11 @@
 import pyspark.sql
 
 from hail.expr.table_type import ttable
-from hail.fs.hadoop_fs import HadoopFS
 from hail.ir import BaseIR
 from hail.ir.renderer import CSERenderer
 from hail.table import Table
 from hail.utils import copy_log
+from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration
 from hailtop.aiotools.router_fs import RouterAsyncFS
 from hailtop.aiotools.validators import validate_file
 from hailtop.utils import async_to_blocking
@@ -47,12 +47,9 @@ def __init__(
         skip_logging_configuration,
         optimizer_iterations,
         *,
-        gcs_requester_pays_project: Optional[str] = None,
-        gcs_requester_pays_buckets: Optional[str] = None,
+        gcs_requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None,
         copy_log_on_error: bool = False,
     ):
-        assert gcs_requester_pays_project is not None or gcs_requester_pays_buckets is None
-
         try:
             local_jar_info = local_jar_information()
         except ValueError:
@@ -120,10 +117,6 @@ def __init__(
                 append,
                 skip_logging_configuration,
                 min_block_size,
-                tmpdir,
-                local_tmpdir,
-                gcs_requester_pays_project,
-                gcs_requester_pays_buckets,
             )
             jhc = hail_package.HailContext.getOrCreate(jbackend, branching_factor, optimizer_iterations)
         else:
@@ -137,10 +130,6 @@ def __init__(
                 append,
                 skip_logging_configuration,
                 min_block_size,
-                tmpdir,
-                local_tmpdir,
-                gcs_requester_pays_project,
-                gcs_requester_pays_buckets,
             )
             jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations)
 
@@ -149,12 +138,12 @@ def __init__(
             self.sc = sc
         else:
             self.sc = pyspark.SparkContext(gateway=self._gateway, jsc=jvm.JavaSparkContext(self._jsc))
-        self._jspark_session = jbackend.sparkSession()
+        self._jspark_session = jbackend.sparkSession().apply()
         self._spark_session = pyspark.sql.SparkSession(self.sc, self._jspark_session)
 
-        super(SparkBackend, self).__init__(jvm, jbackend, jhc)
+        super().__init__(jvm, jbackend, jhc, local_tmpdir, tmpdir)
+        self.gcs_requester_pays_configuration = gcs_requester_pays_config
 
-        self._fs = None
         self._logger = None
 
         if not quiet:
@@ -167,7 +156,7 @@ def __init__(
         self._initialize_flags({})
 
         self._router_async_fs = RouterAsyncFS(
-            gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_project}
+            gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_config}
         )
 
         self._tmpdir = tmpdir
@@ -181,12 +170,6 @@ def stop(self):
         self.sc.stop()
         self.sc = None
 
-    @property
-    def fs(self):
-        if self._fs is None:
-            self._fs = HadoopFS(self._utils_package_object, self._jbackend.fs())
-        return self._fs
-
     def from_spark(self, df, key):
         result_tuple = self._jbackend.pyFromDF(df._jdf, key)
         tir_id, type_json = result_tuple._1(), result_tuple._2()
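
For context (not part of the diff): the two separate requester-pays arguments collapse into a single `gcs_requester_pays_config` value, which the backend now stores and forwards to `RouterAsyncFS` unchanged. Assuming `GCSRequesterPaysConfiguration` is hailtop's usual union of a billing-project string or a `(project, buckets)` pair, a caller-side sketch looks like:

    from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration
    from hailtop.aiotools.router_fs import RouterAsyncFS

    # Bill requester-pays reads in every bucket to one project:
    config: GCSRequesterPaysConfiguration = "my-billing-project"
    # ...or only in specific buckets (assumed (project, buckets) form):
    config = ("my-billing-project", ["bucket-a", "bucket-b"])

    # The same forwarding the diff performs in SparkBackend.__init__:
    fs = RouterAsyncFS(gcs_kwargs={"gcs_requester_pays_configuration": config})

This also makes the deleted assert unnecessary: the tuple form cannot name buckets without a project.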