Skip to content

Commit 54a8581

Browse files
committed
fixes for running benchmarks with no arguments
1 parent 7038c29 commit 54a8581

File tree

5 files changed

+47
-18
lines changed

5 files changed

+47
-18
lines changed

hail/python/benchmark/hail/benchmark_shuffle.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,20 @@
1+
import pytest
2+
13
import hail as hl
24
from benchmark.tools import benchmark
35

46

7+
@pytest.fixture(autouse=True)
8+
def new_query_tmpdir(tmp_path):
9+
backend = hl.current_backend()
10+
old = backend.local_tmpdir
11+
backend.local_tmpdir = str(tmp_path)
12+
try:
13+
yield
14+
finally:
15+
backend.local_tmpdir = old
16+
17+
518
@benchmark()
619
def benchmark_shuffle_key_rows_by_mt(profile25_mt):
720
mt = hl.read_matrix_table(str(profile25_mt))

hail/python/benchmark/hail/conftest.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ def init_hail(run_config):
111111
context = pytest.StashKey[Literal['burn_in', 'benchmark']]()
112112

113113

114+
def prune(kvs: dict) -> dict:
115+
return {k: v for k, v in kvs.items() if v is not None}
116+
117+
114118
@pytest.hookimpl(tryfirst=True)
115119
def pytest_runtest_protocol(item, nextitem):
116120
run_config = item.session.config.run_config
@@ -184,14 +188,16 @@ def pytest_pyfunc_call(pyfuncitem):
184188

185189
is_burn_in = s[context] == 'burn_in'
186190

187-
s[runs].append({
188-
'iteration': s[iteration],
189-
'time': time,
190-
'is_burn_in': is_burn_in,
191-
'timed_out': timed_out,
192-
'failure': traceback,
193-
'task_memory': get_peak_task_memory(Env.hc()._log),
194-
})
191+
s[runs].append(
192+
prune({
193+
'iteration': s[iteration],
194+
'time': time,
195+
'is_burn_in': is_burn_in,
196+
'timed_out': timed_out,
197+
'failure': traceback,
198+
'task_memory': get_peak_task_memory(Env.hc()._log),
199+
})
200+
)
195201

196202
logging.info(f'{"(burn in) " if is_burn_in else ""}iteration {s[iteration]}, time: {time}s')
197203

@@ -210,30 +216,30 @@ def open_file_or_stdout(file):
210216

211217
@pytest.hookimpl
212218
def pytest_sessionfinish(session):
213-
if not session.config.option.collectonly:
219+
if hasattr(session, 'items') and len(session.items) > 0 and not session.config.option.collectonly:
214220
run_config = session.config.run_config
215221

216222
meta = {
217223
'uname': platform.uname()._asdict(),
218224
'version': hl.__version__,
219-
**({'batch_id': batch} if (batch := os.getenv('HAIL_BATCH_ID')) else {}),
220-
**({'job_id': job} if (job := os.getenv('HAIL_JOB_ID')) else {}),
221-
**({'trial': trial} if (trial := os.getenv('BENCHMARK_TRIAL_ID')) else {}),
225+
'batch_id': os.getenv('HAIL_BATCH_ID'),
226+
'job_id': os.getenv('HAIL_JOB_ID'),
227+
'trial': os.getenv('BENCHMARK_TRIAL_ID'),
222228
}
223229

224230
now = datetime.now(timezone.utc).isoformat()
225231
with open_file_or_stdout(run_config.output) as out:
226232
for item in session.items:
227233
path, _, name = item.location
228234
json.dump(
229-
{
235+
prune({
230236
'path': path,
231237
'name': name,
232238
**meta,
233239
'start': item.stash[start],
234240
'end': item.stash.get(end, now),
235241
'runs': item.stash[runs],
236-
},
242+
}),
237243
out,
238244
)
239245
out.write('\n')

hail/python/benchmark/hail/fixtures.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def resource_dir(request, tmpdir_factory):
1515
resource_dir = Path(run_config.data_dir)
1616
resource_dir.mkdir(parents=True, exist_ok=True)
1717
else:
18-
resource_dir = tmpdir_factory.mktemp('hail_benchmark_resources')
18+
resource_dir = Path(tmpdir_factory.mktemp('hail_benchmark_resources'))
1919

2020
return resource_dir
2121

@@ -35,7 +35,7 @@ def __download(data_dir, filename):
3535

3636
def localize(path: Path):
3737
if not path.exists():
38-
path.parent.mkdir(exist_ok=True)
38+
path.parent.mkdir(parents=True, exist_ok=True)
3939
__download(path.parent, path.name)
4040

4141
return path

hail/python/hail/backend/backend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,17 @@ def requires_lowering(self):
398398
def local_tmpdir(self) -> str:
399399
pass
400400

401+
@local_tmpdir.setter
402+
@abc.abstractmethod
403+
def local_tmpdir(self, dir: str) -> None:
404+
pass
405+
401406
@property
402407
@abc.abstractmethod
403408
def remote_tmpdir(self) -> str:
404409
pass
410+
411+
@remote_tmpdir.setter
412+
@abc.abstractmethod
413+
def remote_tmpdir(self, dir: str) -> None:
414+
pass

hail/python/hail/backend/py4j_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def local_tmpdir(self) -> str:
332332
return self._local_tmpdir
333333

334334
@local_tmpdir.setter
335-
def local_tmpdir(self, tmpdir) -> str:
335+
def local_tmpdir(self, tmpdir: str) -> None:
336336
self._local_tmpdir = tmpdir
337337
self._jbackend.pySetLocalTmp(tmpdir)
338338

@@ -341,6 +341,6 @@ def remote_tmpdir(self) -> str:
341341
return self._remote_tmpdir
342342

343343
@remote_tmpdir.setter
344-
def remote_tmpdir(self, tmpdir) -> str:
344+
def remote_tmpdir(self, tmpdir: str) -> None:
345345
self._remote_tmpdir = tmpdir
346346
self._jbackend.pySetRemoteTmp(tmpdir)

0 commit comments

Comments
 (0)