
Commit 6680304

shell scheduler: --use-srun option
1 parent: 9c97feb
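
This flag lets the shell scheduler launch each test through srun instead of mpiexec. Assuming an existing SLURM allocation and Python >= 3.9, a typical invocation would look like `pytest --scheduler=shell --n-workers=8 --use-srun tests/` (worker count and path are placeholders); without the flag the previous mpiexec-based launch is unchanged, and passing it with any other scheduler is rejected.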

2 files changed: +25, -15 lines


pytest_parallel/plugin.py

Lines changed: 7 additions & 1 deletion
@@ -35,9 +35,11 @@ def pytest_addoption(parser):
 
     if sys.version_info >= (3,9):
         parser.addoption('--slurm-export-env', dest='slurm_export_env', action=argparse.BooleanOptionalAction, default=True)
+        parser.addoption('--use-srun', dest='use_srun', action=argparse.BooleanOptionalAction, default=None, help='Launch MPI processes through srun (only possible when `--scheduler=shell`')
     else:
         parser.addoption('--slurm-export-env', dest='slurm_export_env', default=False, action='store_true')
         parser.addoption('--no-slurm-export-env', dest='slurm_export_env', action='store_false')
+        parser.addoption('--use-srun', dest='use_srun', default=None, action='store_true')
 
     parser.addoption('--detach', dest='detach', action='store_true', help='Detach SLURM jobs: do not send reports to the scheduling process (useful to launch slurm job.sh separately)')
 
@@ -104,7 +106,11 @@ def pytest_configure(config):
     is_worker = config.getoption('_worker')
     slurm_file = config.getoption('slurm_file')
     slurm_export_env = config.getoption('slurm_export_env')
+    use_srun = config.getoption('use_srun')
     detach = config.getoption('detach')
+    if scheduler != 'shell':
+        if use_srun is not None:
+            raise PytestParallelError('Option `--use-srun` only available when `--scheduler=shell`')
     if not scheduler in ['slurm', 'shell']:
         assert not is_worker, f'Internal pytest_parallel error `--_worker` not available with`--scheduler={scheduler}`'
         assert not n_workers, f'pytest_parallel error `--n-workers` not available with`--scheduler={scheduler}`. Launch with `mpirun -np {n_workers}` to run in parallel'
@@ -175,7 +181,7 @@ def pytest_configure(config):
         main_invoke_params = _invoke_params(config.invocation_params.args)
         for file_or_dir in config.option.file_or_dir:
             main_invoke_params = main_invoke_params.replace(file_or_dir, '')
-        plugin = ShellStaticScheduler(main_invoke_params, n_workers, detach)
+        plugin = ShellStaticScheduler(main_invoke_params, n_workers, detach, use_srun)
     else:
         from mpi4py import MPI
         from .mpi_reporter import SequentialScheduler, StaticScheduler, DynamicScheduler
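
Note on the option declaration above: with argparse.BooleanOptionalAction and default=None, the option is tri-state, which is what lets pytest_configure distinguish "flag not passed" from an explicit choice and reject any explicit --use-srun / --no-use-srun when the scheduler is not shell. A minimal standalone sketch of that behavior, assuming Python >= 3.9 (illustration only, not the plugin code itself):

import argparse

# Tri-state flag: absent -> None, --use-srun -> True, --no-use-srun -> False
parser = argparse.ArgumentParser()
parser.add_argument('--use-srun', dest='use_srun',
                    action=argparse.BooleanOptionalAction, default=None)

assert parser.parse_args([]).use_srun is None
assert parser.parse_args(['--use-srun']).use_srun is True
assert parser.parse_args(['--no-use-srun']).use_srun is False

This is also why the new validation checks `use_srun is not None` rather than truthiness: even an explicit --no-use-srun is rejected when `--scheduler` is not `shell`.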

pytest_parallel/shell_static_scheduler.py

Lines changed: 18 additions & 14 deletions
@@ -13,19 +13,22 @@
 from .utils.file import remove_exotic_chars, create_folders
 from .static_scheduler_utils import group_items_by_parallel_steps
 
-def mpi_command(current_proc, n_proc):
-    mpi_vendor = MPI.get_vendor()[0]
-    if mpi_vendor == 'Intel MPI':
-        cmd = f'I_MPI_PIN_PROCESSOR_LIST={current_proc}-{current_proc+n_proc-1}; '
-        cmd += f'mpiexec -np {n_proc}'
-        return cmd
-    elif mpi_vendor == 'Open MPI':
-        cores = ','.join([str(i) for i in range(current_proc,current_proc+n_proc)])
-        return f'mpiexec --cpu-list {cores} -np {n_proc}'
+def mpi_command(current_proc, n_proc, use_srun):
+    if use_srun:
+        return f'srun --exact --ntasks={n_proc}'
     else:
-        assert 0, f'Unknown MPI implementation "{mpi_vendor}"'
+        mpi_vendor = MPI.get_vendor()[0]
+        if mpi_vendor == 'Intel MPI':
+            cmd = f'I_MPI_PIN_PROCESSOR_LIST={current_proc}-{current_proc+n_proc-1}; '
+            cmd += f'mpiexec -np {n_proc}'
+            return cmd
+        elif mpi_vendor == 'Open MPI':
+            cores = ','.join([str(i) for i in range(current_proc,current_proc+n_proc)])
+            return f'mpiexec --cpu-list {cores} -np {n_proc}'
+        else:
+            assert 0, f'Unknown MPI implementation "{mpi_vendor}"'
 
-def submit_items(items_to_run, SCHEDULER_IP_ADDRESS, port, session_folder, main_invoke_params, i_step, n_step):
+def submit_items(items_to_run, SCHEDULER_IP_ADDRESS, port, session_folder, main_invoke_params, use_srun, i_step, n_step):
     # sort item by comm size to launch bigger first (Note: in case SLURM prioritize first-received items)
     items = sorted(items_to_run, key=lambda item: item.n_proc, reverse=True)
 
@@ -40,7 +43,7 @@ def submit_items(items_to_run, SCHEDULER_IP_ADDRESS, port, session_folder, main_
         test_idx = item.original_index
         test_out_file = f'.pytest_parallel/{session_folder}/{remove_exotic_chars(item.nodeid)}'
         cmd = '('
-        cmd += mpi_command(current_proc, item.n_proc)
+        cmd += mpi_command(current_proc, item.n_proc, use_srun)
         cmd += f' python3 -u -m pytest -s --_worker {socket_flags} {main_invoke_params} --_test_idx={test_idx} {item.config.rootpath}/{item.nodeid}'
         cmd += f' > {test_out_file} 2>&1'
         cmd += f' ; python3 -m pytest_parallel.send_report {socket_flags} --_test_idx={test_idx} --_test_name={test_out_file}'
@@ -99,10 +102,11 @@ def receive_items(items, session, socket, n_item_to_recv):
            n_item_to_recv -= 1
 
 class ShellStaticScheduler:
-    def __init__(self, main_invoke_params, ntasks, detach):
+    def __init__(self, main_invoke_params, ntasks, detach, use_srun):
         self.main_invoke_params = main_invoke_params
         self.ntasks = ntasks
         self.detach = detach
+        self.use_srun = use_srun
 
         self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TODO close at the end
 
@@ -148,7 +152,7 @@ def pytest_runtestloop(self, session) -> bool:
         n_step = len(items_by_steps)
         for i_step,items in enumerate(items_by_steps):
             n_item_to_receive = len(items)
-            sub_process = submit_items(items, SCHEDULER_IP_ADDRESS, port, session_folder, self.main_invoke_params, i_step, n_step)
+            sub_process = submit_items(items, SCHEDULER_IP_ADDRESS, port, session_folder, self.main_invoke_params, self.use_srun, i_step, n_step)
             if not self.detach: # The job steps are supposed to send their reports
                 receive_items(session.items, session, self.socket, n_item_to_receive)
             returncode = sub_process.wait() # at this point, the sub-process should be done since items have been received
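
To make the dispatch concrete: for a test requesting 4 processes placed from core 0, the two branches of mpi_command build launcher prefixes like the ones below (illustrative values only; the non-srun case shows the Open MPI branch, and the vendor is detected at runtime via MPI.get_vendor()). submit_items then appends the `python3 -u -m pytest ... --_worker ...` worker command to this prefix.

n_proc, current_proc = 4, 0   # illustrative values

# --use-srun: run the test as a SLURM job step, letting SLURM place the ranks
srun_prefix = f'srun --exact --ntasks={n_proc}'
# -> 'srun --exact --ntasks=4'

# without --use-srun (Open MPI branch): pin the test to an explicit core list
cores = ','.join(str(i) for i in range(current_proc, current_proc + n_proc))
openmpi_prefix = f'mpiexec --cpu-list {cores} -np {n_proc}'
# -> 'mpiexec --cpu-list 0,1,2,3 -np 4'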
