From 9e0a3975043b0eeb12ca68ac7b951bd034a5cb37 Mon Sep 17 00:00:00 2001 From: Alexey Semenyuk Date: Sun, 27 Jul 2025 00:07:34 +0500 Subject: [PATCH] gh-137128: support for async iterables in coro_fns --- Lib/asyncio/staggered.py | 23 ++- Lib/test/test_asyncio/test_staggered.py | 197 ++++++++++++++++++++++++ 2 files changed, 215 insertions(+), 5 deletions(-) diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py index 2ad65d8648e6c5..e7191b5d34bfe4 100644 --- a/Lib/asyncio/staggered.py +++ b/Lib/asyncio/staggered.py @@ -38,7 +38,7 @@ async def staggered_race(coro_fns, delay, *, loop=None): Args: coro_fns: an iterable of coroutine functions, i.e. callables that return a coroutine object when called. Use ``functools.partial`` or - lambdas to pass arguments. + lambdas to pass arguments. Can also be an async iterable. delay: amount of time, in seconds, between starting coroutines. If ``None``, the coroutines will run sequentially. @@ -62,10 +62,19 @@ async def staggered_race(coro_fns, delay, *, loop=None): coroutine's entry is ``None``. """ - # TODO: when we have aiter() and anext(), allow async iterables in coro_fns. + # Support for async iterables in coro_fns + try: + # Try to get an async iterator + aiter_coro_fns = aiter(coro_fns) + is_async_iterable = True + enum_coro_fns = None + except TypeError: + # Not an async iterable, use regular iteration + enum_coro_fns = enumerate(coro_fns) + is_async_iterable = False + aiter_coro_fns = None loop = loop or events.get_running_loop() parent_task = tasks.current_task(loop) - enum_coro_fns = enumerate(coro_fns) winner_result = None winner_index = None unhandled_exceptions = [] @@ -106,8 +115,12 @@ async def run_one_coro(ok_to_start, previous_failed) -> None: await tasks.wait_for(previous_failed.wait(), delay) # Get the next coroutine to run try: - this_index, coro_fn = next(enum_coro_fns) - except StopIteration: + if is_async_iterable: + coro_fn = await anext(aiter_coro_fns) + this_index = len(exceptions) # Track index manually for async iterables + else: + this_index, coro_fn = next(enum_coro_fns) + except (StopIteration, StopAsyncIteration): return # Start task that will run the next coroutine this_failed = locks.Event() diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py index ad34aa6da01f54..3178fbd812fd0f 100644 --- a/Lib/test/test_asyncio/test_staggered.py +++ b/Lib/test/test_asyncio/test_staggered.py @@ -149,3 +149,200 @@ async def coro_fn(): raise self.assertListEqual(log, ["cancelled 1", "cancelled 2", "cancelled 3"]) + + async def test_async_iterable_empty(self): + async def empty_async_iterable(): + if False: + yield lambda: asyncio.sleep(0) + + winner, index, excs = await staggered_race( + empty_async_iterable(), + delay=None, + ) + + self.assertIs(winner, None) + self.assertIs(index, None) + self.assertEqual(excs, []) + + async def test_async_iterable_one_successful(self): + async def async_coro_generator(): + async def coro(index): + return f'Async Res: {index}' + + yield lambda: coro(0) + yield lambda: coro(1) + + winner, index, excs = await staggered_race( + async_coro_generator(), + delay=None, + ) + + self.assertEqual(winner, 'Async Res: 0') + self.assertEqual(index, 0) + self.assertEqual(excs, [None]) + + async def test_async_iterable_first_error_second_successful(self): + async def async_coro_generator(): + async def coro(index): + if index == 0: + raise ValueError(f'Async Error: {index}') + return f'Async Res: {index}' + + yield lambda: coro(0) + yield lambda: coro(1) + + winner, index, excs = await staggered_race( + async_coro_generator(), + delay=None, + ) + + self.assertEqual(winner, 'Async Res: 1') + self.assertEqual(index, 1) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], ValueError) + self.assertEqual(str(excs[0]), 'Async Error: 0') + self.assertIs(excs[1], None) + + async def test_async_iterable_first_timeout_second_successful(self): + async def async_coro_generator(): + async def coro(index): + if index == 0: + await asyncio.sleep(10) + return f'Async Res: {index}' + + yield lambda: coro(0) + yield lambda: coro(1) + + winner, index, excs = await staggered_race( + async_coro_generator(), + delay=0.1, + ) + + self.assertEqual(winner, 'Async Res: 1') + self.assertEqual(index, 1) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], asyncio.CancelledError) + self.assertIs(excs[1], None) + + async def test_async_iterable_none_successful(self): + async def async_coro_generator(): + async def coro(index): + raise ValueError(f'Async Error: {index}') + + yield lambda: coro(0) + yield lambda: coro(1) + + winner, index, excs = await staggered_race( + async_coro_generator(), + delay=None, + ) + + self.assertIs(winner, None) + self.assertIs(index, None) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], ValueError) + self.assertEqual(str(excs[0]), 'Async Error: 0') + self.assertIsInstance(excs[1], ValueError) + self.assertEqual(str(excs[1]), 'Async Error: 1') + + async def test_async_iterable_multiple_winners(self): + event = asyncio.Event() + + async def async_coro_generator(): + async def coro(index): + await event.wait() + return f'Async Index: {index}' + + async def do_set(): + event.set() + await asyncio.Event().wait() + + yield lambda: coro(0) + yield lambda: coro(1) + yield do_set + + winner, index, excs = await staggered_race( + async_coro_generator(), + delay=0.1, + ) + + self.assertEqual(winner, 'Async Index: 0') + self.assertEqual(index, 0) + self.assertEqual(len(excs), 3) + self.assertIsNone(excs[0]) + self.assertIsInstance(excs[1], asyncio.CancelledError) + self.assertIsInstance(excs[2], asyncio.CancelledError) + + async def test_async_iterable_with_delay(self): + results = [] + + async def async_coro_generator(): + async def coro(index): + results.append(f'Started: {index}') + await asyncio.sleep(0.05) + return f'Result: {index}' + + yield lambda: coro(0) + yield lambda: coro(1) + yield lambda: coro(2) + + winner, index, excs = await staggered_race( + async_coro_generator(), + delay=0.02, + ) + + self.assertEqual(winner, 'Result: 0') + self.assertEqual(index, 0) + + self.assertGreaterEqual(len(excs), 1) + self.assertIsNone(excs[0]) + + self.assertIn('Started: 0', results) + + async def test_async_iterable_mixed_with_regular(self): + async def coro(index): + return f'Mixed Res: {index}' + + winner, index, excs = await staggered_race( + [lambda: coro(0), lambda: coro(1)], + delay=None, + ) + + self.assertEqual(winner, 'Mixed Res: 0') + self.assertEqual(index, 0) + self.assertEqual(excs, [None]) + + async def test_async_iterable_cancelled(self): + log = [] + + async def async_coro_generator(): + async def coro_fn(): + try: + await asyncio.sleep(0.1) + except asyncio.CancelledError: + log.append("async cancelled") + raise + + yield coro_fn + + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0.01): + await staggered_race(async_coro_generator(), delay=None) + + self.assertListEqual(log, ["async cancelled"]) + + async def test_async_iterable_stop_async_iteration(self): + async def async_coro_generator(): + async def coro(): + return "success" + + yield lambda: coro() + + winner, index, excs = await staggered_race( + async_coro_generator(), + delay=None, + ) + + self.assertEqual(winner, "success") + self.assertEqual(index, 0) + self.assertEqual(excs, [None])