Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion hail/hail/src/is/hail/backend/driver/Py4JQueryDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,15 @@ final class Py4JQueryDriver(backend: Backend) extends Closeable {
synchronized { tmpdir = tmp }

// Stores `tmp` in `localTmpdir` under this object's lock and, when the wrapped
// backend is a SparkBackend, also writes it to the "spark.local.dir" setting.
def pySetLocalTmp(tmp: String): Unit =
// Diff-view residue: the next line is the previous one-line body (the single
// deletion reported for this file); the block after it is its replacement.
synchronized { localTmpdir = tmp }
synchronized {
localTmpdir = tmp
backend match {
case s: SparkBackend =>
// NOTE(review): SparkContext.getConf is documented as returning a copy of
// the configuration - confirm this set() actually reaches the running
// context; spark.local.dir is normally fixed when the context is created.
void(s.sc.getConf.set("spark.local.dir", tmp))
case _ =>
// Any other backend: nothing beyond the field update above.
()
}
}

def pySetGcsRequesterPaysConfig(project: String, buckets: util.List[String]): Unit =
synchronized {
Expand Down
20 changes: 20 additions & 0 deletions hail/python/hail/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,3 +393,23 @@ def get_flags(self, *flags) -> Mapping[str, str]:
@abc.abstractmethod
def requires_lowering(self):
pass

@property
@abc.abstractmethod
def local_tmpdir(self) -> str:
    """Return the backend's local temporary directory.

    Abstract: each concrete backend supplies its own getter.
    """

@local_tmpdir.setter
@abc.abstractmethod
def local_tmpdir(self, tmpdir: str) -> None:
    """Set the backend's local temporary directory to *tmpdir*.

    Abstract: each concrete backend supplies its own setter. The argument is
    named ``tmpdir`` so that it does not shadow the builtin ``dir``.
    """

@property
@abc.abstractmethod
def remote_tmpdir(self) -> str:
    """Return the backend's remote temporary directory.

    Abstract: each concrete backend supplies its own getter.
    """

@remote_tmpdir.setter
@abc.abstractmethod
def remote_tmpdir(self, tmpdir: str) -> None:
    """Set the backend's remote temporary directory to *tmpdir*.

    Abstract: each concrete backend supplies its own setter. The argument is
    named ``tmpdir`` so that it does not shadow the builtin ``dir``.
    """
22 changes: 20 additions & 2 deletions hail/python/hail/backend/py4j_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ def decode_bytearray(encoded):
self._jhc = jhc

self._jbackend = self._hail_package.backend.driver.Py4JQueryDriver(jbackend)
self._jbackend.pySetLocalTmp(tmpdir)
self._jbackend.pySetRemoteTmp(remote_tmpdir)
self.local_tmpdir = tmpdir
self.remote_tmpdir = remote_tmpdir

self._jhttp_server = self._jbackend.pyHttpServer()

Expand Down Expand Up @@ -341,3 +341,21 @@ def stop(self):
self._jhc = None
uninstall_exception_handler()
super().stop()

@property
def local_tmpdir(self) -> str:
    """Return the local temporary directory last applied through the setter."""
    return self._local_tmpdir

@local_tmpdir.setter
def local_tmpdir(self, tmpdir: str) -> None:
    """Send *tmpdir* to the JVM driver, then remember it on the Python side.

    The value is cached only after ``pySetLocalTmp`` returns, so a failing
    JVM call leaves the previously cached directory untouched instead of
    reporting one the driver never accepted.
    """
    # JVM first: an exception here must not change what the getter reports.
    self._jbackend.pySetLocalTmp(tmpdir)
    self._local_tmpdir = tmpdir

@property
def remote_tmpdir(self) -> str:
    """Return the remote temporary directory last applied through the setter."""
    return self._remote_tmpdir

@remote_tmpdir.setter
def remote_tmpdir(self, tmpdir: str) -> None:
    """Send *tmpdir* to the JVM driver, then remember it on the Python side.

    The value is cached only after ``pySetRemoteTmp`` returns, so a failing
    JVM call leaves the previously cached directory untouched instead of
    reporting one the driver never accepted.
    """
    # JVM first: an exception here must not change what the getter reports.
    self._jbackend.pySetRemoteTmp(tmpdir)
    self._remote_tmpdir = tmpdir
20 changes: 18 additions & 2 deletions hail/python/hail/backend/service_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import warnings
from contextlib import AsyncExitStack
from dataclasses import dataclass
from typing import Any, Awaitable, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union
from typing import Any, Awaitable, Dict, List, Mapping, NoReturn, Optional, Set, Tuple, TypeVar, Union

import orjson

Expand Down Expand Up @@ -230,7 +230,7 @@ def __init__(
self._batch = batch
self._job_group_was_submitted: bool = False
self.disable_progress_bar = disable_progress_bar
self.remote_tmpdir = remote_tmpdir
self._remote_tmpdir = remote_tmpdir
self.flags: Dict[str, str] = {}
self._registered_ir_function_names: Set[str] = set()
self.driver_cores = driver_cores
Expand Down Expand Up @@ -502,3 +502,19 @@ def get_flags(self, *flags: str) -> Mapping[str, str]:
@property
def requires_lowering(self):
    """Report whether this backend requires lowering; it always does."""
    lowering_required = True
    return lowering_required

@property
def local_tmpdir(self) -> NoReturn:
    """Reject reads: this backend has no local temporary directory.

    Raises:
        AttributeError: always.
    """
    reason = 'local tmp folders are not supported on the batch backend'
    raise AttributeError(reason)

@local_tmpdir.setter
def local_tmpdir(self, tmpdir: str) -> NoReturn:
    """Reject writes: this backend has no local temporary directory.

    Raises:
        AttributeError: always; *tmpdir* is never stored.
    """
    reason = 'local tmp folders are not supported on the batch backend'
    raise AttributeError(reason)

@property
def remote_tmpdir(self) -> str:
    """Return the remote temporary directory held in ``_remote_tmpdir``."""
    current = self._remote_tmpdir
    return current

@remote_tmpdir.setter
def remote_tmpdir(self, tmpdir: str) -> None:
    """Store *tmpdir* in ``_remote_tmpdir``; nothing else is updated here."""
    self._remote_tmpdir = tmpdir
35 changes: 13 additions & 22 deletions hail/python/hail/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ def create(
log: str,
quiet: bool,
append: bool,
tmpdir: str,
local_tmpdir: str,
default_reference: str,
global_seed: Optional[int],
backend: Backend,
Expand All @@ -76,25 +74,17 @@ def create(
log=log,
quiet=quiet,
append=append,
tmpdir=tmpdir,
local_tmpdir=local_tmpdir,
global_seed=global_seed,
backend=backend,
)
hc.initialize_references(default_reference)
return hc

@typecheck_method(
log=str, quiet=bool, append=bool, tmpdir=str, local_tmpdir=str, global_seed=nullable(int), backend=Backend
)
def __init__(self, log, quiet, append, tmpdir, local_tmpdir, global_seed, backend):
@typecheck_method(log=str, quiet=bool, append=bool, global_seed=nullable(int), backend=Backend)
def __init__(self, log, quiet, append, global_seed, backend: Backend):
assert not Env._hc

self._log = log

self._tmpdir = tmpdir
self._local_tmpdir = local_tmpdir

self._backend = backend

self._warn_cols_order = True
Expand Down Expand Up @@ -136,6 +126,14 @@ def initialize_references(self, default_reference):
else:
self._default_ref = ReferenceGenome.read(default_reference)

@property
def _tmpdir(self) -> str:
    """Read-only view of the backend's ``remote_tmpdir``."""
    backend = self._backend
    return backend.remote_tmpdir

@property
def _local_tmpdir(self) -> str:
    """Read-only view of the backend's ``local_tmpdir``."""
    backend = self._backend
    return backend.local_tmpdir

@property
def default_reference(self) -> ReferenceGenome:
assert self._default_ref is not None, '_default_ref should have been initialized in HailContext.create'
Expand Down Expand Up @@ -376,7 +374,6 @@ def init(
quiet=quiet,
append=append,
tmpdir=tmp_dir,
local_tmpdir=local_tmpdir,
default_reference=default_reference,
global_seed=global_seed,
driver_cores=driver_cores,
Expand Down Expand Up @@ -503,7 +500,7 @@ def init_spark(
if not backend.fs.exists(tmpdir):
backend.fs.mkdir(tmpdir)

HailContext.create(log, quiet, append, tmpdir, local_tmpdir, default_reference, global_seed, backend)
HailContext.create(log, quiet, append, default_reference, global_seed, backend)
if not quiet:
connect_logger(backend._utils_package_object, 'localhost', 12888)

Expand All @@ -515,7 +512,6 @@ def init_spark(
quiet=bool,
append=bool,
tmpdir=nullable(str),
local_tmpdir=nullable(str),
default_reference=enumeration(*BUILTIN_REFERENCES),
global_seed=nullable(int),
disable_progress_bar=nullable(bool),
Expand All @@ -538,7 +534,6 @@ async def init_batch(
quiet: bool = False,
append: bool = False,
tmpdir: Optional[str] = None,
local_tmpdir: Optional[str] = None,
default_reference: str = 'GRCh37',
global_seed: Optional[int] = None,
disable_progress_bar: Optional[bool] = None,
Expand Down Expand Up @@ -573,11 +568,7 @@ async def init_batch(
)

log = _get_log(log)
if tmpdir is None:
tmpdir = os.path.join(backend.remote_tmpdir, 'tmp/hail', secret_alnum_string())
local_tmpdir = _get_local_tmpdir(local_tmpdir)

HailContext.create(log, quiet, append, tmpdir, local_tmpdir, default_reference, global_seed, backend)
HailContext.create(log, quiet, append, default_reference, global_seed, backend)


@typecheck(
Expand Down Expand Up @@ -629,7 +620,7 @@ def init_local(
if not backend.fs.exists(tmpdir):
backend.fs.mkdir(tmpdir)

HailContext.create(log, quiet, append, tmpdir, tmpdir, default_reference, global_seed, backend)
HailContext.create(log, quiet, append, default_reference, global_seed, backend)
if not quiet:
connect_logger(backend._utils_package_object, 'localhost', 12888)

Expand Down