Skip to content

Commit 3a6917a

Browse files
authored
[query] make local and remote tmp settable on backend (#14748)
Refactors temporary directory handling in Hail's backend system by introducing abstract properties for local and remote temporary directories. This moves the temporary directory management from the HailContext into the backend implementations. This change has no security impact
1 parent b10367e commit 3a6917a

File tree

5 files changed

+80
-27
lines changed

5 files changed

+80
-27
lines changed

hail/hail/src/is/hail/backend/driver/Py4JQueryDriver.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,15 @@ final class Py4JQueryDriver(backend: Backend) extends Closeable {
7171
synchronized { tmpdir = tmp }
7272

7373
def pySetLocalTmp(tmp: String): Unit =
74-
synchronized { localTmpdir = tmp }
74+
synchronized {
75+
localTmpdir = tmp
76+
backend match {
77+
case s: SparkBackend =>
78+
void(s.sc.getConf.set("spark.local.dir", tmp))
79+
case _ =>
80+
()
81+
}
82+
}
7583

7684
def pySetGcsRequesterPaysConfig(project: String, buckets: util.List[String]): Unit =
7785
synchronized {

hail/python/hail/backend/backend.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,3 +393,23 @@ def get_flags(self, *flags) -> Mapping[str, str]:
393393
@abc.abstractmethod
394394
def requires_lowering(self):
395395
pass
396+
397+
@property
398+
@abc.abstractmethod
399+
def local_tmpdir(self) -> str:
400+
pass
401+
402+
@local_tmpdir.setter
403+
@abc.abstractmethod
404+
def local_tmpdir(self, dir: str) -> None:
405+
pass
406+
407+
@property
408+
@abc.abstractmethod
409+
def remote_tmpdir(self) -> str:
410+
pass
411+
412+
@remote_tmpdir.setter
413+
@abc.abstractmethod
414+
def remote_tmpdir(self, dir: str) -> None:
415+
pass

hail/python/hail/backend/py4j_backend.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ def decode_bytearray(encoded):
198198
self._jhc = jhc
199199

200200
self._jbackend = self._hail_package.backend.driver.Py4JQueryDriver(jbackend)
201-
self._jbackend.pySetLocalTmp(tmpdir)
202-
self._jbackend.pySetRemoteTmp(remote_tmpdir)
201+
self.local_tmpdir = tmpdir
202+
self.remote_tmpdir = remote_tmpdir
203203

204204
self._jhttp_server = self._jbackend.pyHttpServer()
205205

@@ -341,3 +341,21 @@ def stop(self):
341341
self._jhc = None
342342
uninstall_exception_handler()
343343
super().stop()
344+
345+
@property
346+
def local_tmpdir(self) -> str:
347+
return self._local_tmpdir
348+
349+
@local_tmpdir.setter
350+
def local_tmpdir(self, tmpdir: str) -> None:
351+
self._local_tmpdir = tmpdir
352+
self._jbackend.pySetLocalTmp(tmpdir)
353+
354+
@property
355+
def remote_tmpdir(self) -> str:
356+
return self._remote_tmpdir
357+
358+
@remote_tmpdir.setter
359+
def remote_tmpdir(self, tmpdir: str) -> None:
360+
self._remote_tmpdir = tmpdir
361+
self._jbackend.pySetRemoteTmp(tmpdir)

hail/python/hail/backend/service_backend.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import warnings
66
from contextlib import AsyncExitStack
77
from dataclasses import dataclass
8-
from typing import Any, Awaitable, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union
8+
from typing import Any, Awaitable, Dict, List, Mapping, NoReturn, Optional, Set, Tuple, TypeVar, Union
99

1010
import orjson
1111

@@ -230,7 +230,7 @@ def __init__(
230230
self._batch = batch
231231
self._job_group_was_submitted: bool = False
232232
self.disable_progress_bar = disable_progress_bar
233-
self.remote_tmpdir = remote_tmpdir
233+
self._remote_tmpdir = remote_tmpdir
234234
self.flags: Dict[str, str] = {}
235235
self._registered_ir_function_names: Set[str] = set()
236236
self.driver_cores = driver_cores
@@ -502,3 +502,19 @@ def get_flags(self, *flags: str) -> Mapping[str, str]:
502502
@property
503503
def requires_lowering(self):
504504
return True
505+
506+
@property
507+
def local_tmpdir(self) -> NoReturn:
508+
raise AttributeError('local tmp folders are not supported on the batch backend')
509+
510+
@local_tmpdir.setter
511+
def local_tmpdir(self, tmpdir: str) -> NoReturn:
512+
raise AttributeError('local tmp folders are not supported on the batch backend')
513+
514+
@property
515+
def remote_tmpdir(self) -> str:
516+
return self._remote_tmpdir
517+
518+
@remote_tmpdir.setter
519+
def remote_tmpdir(self, tmpdir: str) -> None:
520+
self._remote_tmpdir = tmpdir

hail/python/hail/context.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,6 @@ def create(
6666
log: str,
6767
quiet: bool,
6868
append: bool,
69-
tmpdir: str,
70-
local_tmpdir: str,
7169
default_reference: str,
7270
global_seed: Optional[int],
7371
backend: Backend,
@@ -76,25 +74,17 @@ def create(
7674
log=log,
7775
quiet=quiet,
7876
append=append,
79-
tmpdir=tmpdir,
80-
local_tmpdir=local_tmpdir,
8177
global_seed=global_seed,
8278
backend=backend,
8379
)
8480
hc.initialize_references(default_reference)
8581
return hc
8682

87-
@typecheck_method(
88-
log=str, quiet=bool, append=bool, tmpdir=str, local_tmpdir=str, global_seed=nullable(int), backend=Backend
89-
)
90-
def __init__(self, log, quiet, append, tmpdir, local_tmpdir, global_seed, backend):
83+
@typecheck_method(log=str, quiet=bool, append=bool, global_seed=nullable(int), backend=Backend)
84+
def __init__(self, log, quiet, append, global_seed, backend: Backend):
9185
assert not Env._hc
9286

9387
self._log = log
94-
95-
self._tmpdir = tmpdir
96-
self._local_tmpdir = local_tmpdir
97-
9888
self._backend = backend
9989

10090
self._warn_cols_order = True
@@ -136,6 +126,14 @@ def initialize_references(self, default_reference):
136126
else:
137127
self._default_ref = ReferenceGenome.read(default_reference)
138128

129+
@property
130+
def _tmpdir(self) -> str:
131+
return self._backend.remote_tmpdir
132+
133+
@property
134+
def _local_tmpdir(self) -> str:
135+
return self._backend.local_tmpdir
136+
139137
@property
140138
def default_reference(self) -> ReferenceGenome:
141139
assert self._default_ref is not None, '_default_ref should have been initialized in HailContext.create'
@@ -376,7 +374,6 @@ def init(
376374
quiet=quiet,
377375
append=append,
378376
tmpdir=tmp_dir,
379-
local_tmpdir=local_tmpdir,
380377
default_reference=default_reference,
381378
global_seed=global_seed,
382379
driver_cores=driver_cores,
@@ -503,7 +500,7 @@ def init_spark(
503500
if not backend.fs.exists(tmpdir):
504501
backend.fs.mkdir(tmpdir)
505502

506-
HailContext.create(log, quiet, append, tmpdir, local_tmpdir, default_reference, global_seed, backend)
503+
HailContext.create(log, quiet, append, default_reference, global_seed, backend)
507504
if not quiet:
508505
connect_logger(backend._utils_package_object, 'localhost', 12888)
509506

@@ -515,7 +512,6 @@ def init_spark(
515512
quiet=bool,
516513
append=bool,
517514
tmpdir=nullable(str),
518-
local_tmpdir=nullable(str),
519515
default_reference=enumeration(*BUILTIN_REFERENCES),
520516
global_seed=nullable(int),
521517
disable_progress_bar=nullable(bool),
@@ -538,7 +534,6 @@ async def init_batch(
538534
quiet: bool = False,
539535
append: bool = False,
540536
tmpdir: Optional[str] = None,
541-
local_tmpdir: Optional[str] = None,
542537
default_reference: str = 'GRCh37',
543538
global_seed: Optional[int] = None,
544539
disable_progress_bar: Optional[bool] = None,
@@ -573,11 +568,7 @@ async def init_batch(
573568
)
574569

575570
log = _get_log(log)
576-
if tmpdir is None:
577-
tmpdir = os.path.join(backend.remote_tmpdir, 'tmp/hail', secret_alnum_string())
578-
local_tmpdir = _get_local_tmpdir(local_tmpdir)
579-
580-
HailContext.create(log, quiet, append, tmpdir, local_tmpdir, default_reference, global_seed, backend)
571+
HailContext.create(log, quiet, append, default_reference, global_seed, backend)
581572

582573

583574
@typecheck(
@@ -629,7 +620,7 @@ def init_local(
629620
if not backend.fs.exists(tmpdir):
630621
backend.fs.mkdir(tmpdir)
631622

632-
HailContext.create(log, quiet, append, tmpdir, tmpdir, default_reference, global_seed, backend)
623+
HailContext.create(log, quiet, append, default_reference, global_seed, backend)
633624
if not quiet:
634625
connect_logger(backend._utils_package_object, 'localhost', 12888)
635626

0 commit comments

Comments
 (0)