diff --git a/src/executorlib/executor/base.py b/src/executorlib/executor/base.py index 281d4d8d..2925549f 100644 --- a/src/executorlib/executor/base.py +++ b/src/executorlib/executor/base.py @@ -107,6 +107,43 @@ def submit( # type: ignore else: raise RuntimeError("cannot schedule new futures after shutdown") + def map( + self, + fn: Callable, + *iterables, + timeout: Optional[float] = None, + chunksize: int = 1, + ): + """Returns an iterator equivalent to map(fn, iter). + + Args: + fn: A callable that will take as many arguments as there are + passed iterables. + timeout: The maximum number of seconds to wait. If None, then there + is no limit on the wait time. + chunksize: The size of the chunks the iterable will be broken into + before being passed to a child process. This argument is only + used by ProcessPoolExecutor; it is ignored by + ThreadPoolExecutor. + + Returns: + An iterator equivalent to: map(func, *iterables) but the calls may + be evaluated out-of-order. + + Raises: + TimeoutError: If the entire result iterator could not be generated + before the given timeout. + Exception: If fn(*args) raises for any values. + """ + if self._is_active: + return self._task_scheduler.map( + *([fn] + list(iterables)), + timeout=timeout, + chunksize=chunksize, + ) + else: + raise RuntimeError("cannot schedule new futures after shutdown") + def shutdown(self, wait: bool = True, *, cancel_futures: bool = False): """ Clean-up the resources associated with the Executor. diff --git a/src/executorlib/task_scheduler/base.py b/src/executorlib/task_scheduler/base.py index 0961318d..8940fd7a 100644 --- a/src/executorlib/task_scheduler/base.py +++ b/src/executorlib/task_scheduler/base.py @@ -143,6 +143,43 @@ def submit( # type: ignore ) return f + def map( + self, + fn: Callable, + *iterables, + timeout: Optional[float] = None, + chunksize: int = 1, + ): + """Returns an iterator equivalent to map(fn, iter). + + Args: + fn: A callable that will take as many arguments as there are + passed iterables. + timeout: The maximum number of seconds to wait. If None, then there + is no limit on the wait time. + chunksize: The size of the chunks the iterable will be broken into + before being passed to a child process. This argument is only + used by ProcessPoolExecutor; it is ignored by + ThreadPoolExecutor. + + Returns: + An iterator equivalent to: map(func, *iterables) but the calls may + be evaluated out-of-order. + + Raises: + TimeoutError: If the entire result iterator could not be generated + before the given timeout. + Exception: If fn(*args) raises for any values. + """ + if isinstance(iterables, (list, tuple)) and any( + isinstance(i, Future) for i in iterables + ): + iterables = tuple( + i.result() if isinstance(i, Future) else i for i in iterables + ) + + return super().map(fn, *iterables, timeout=timeout, chunksize=chunksize) + def shutdown(self, wait: bool = True, *, cancel_futures: bool = False): """ Clean-up the resources associated with the Executor. diff --git a/tests/test_singlenodeexecutor_mpi.py b/tests/test_singlenodeexecutor_mpi.py index 09bda08a..61908343 100644 --- a/tests/test_singlenodeexecutor_mpi.py +++ b/tests/test_singlenodeexecutor_mpi.py @@ -125,6 +125,21 @@ def test_output_files_cwd(self): [1, 2, 3], ) + def test_map_futures(self): + dirname = os.path.abspath(os.path.dirname(__file__)) + os.makedirs(dirname, exist_ok=True) + with SingleNodeExecutor( + max_cores=1, + resource_dict={"cores": 1, "cwd": dirname}, + block_allocation=True, + ) as p: + calc_lst = p.submit(calc, [1, 2, 3]) + output = list(p.map(calc, calc_lst)) + self.assertEqual( + output, + [1, 2, 3], + ) + class TestSLURMExecutor(unittest.TestCase): def test_validate_max_workers(self): diff --git a/tests/test_singlenodeexecutor_noblock.py b/tests/test_singlenodeexecutor_noblock.py index 06486330..b0606412 100644 --- a/tests/test_singlenodeexecutor_noblock.py +++ b/tests/test_singlenodeexecutor_noblock.py @@ -173,3 +173,5 @@ def test_single_node_executor_exit(self): exe.shutdown() with self.assertRaises(RuntimeError): exe.submit(sum, [1, 2, 3]) + with self.assertRaises(RuntimeError): + exe.map(calc, [1, 2, 3])