Skip to content

gh-137128: Support for async iterables in coro_fns #137129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions Lib/asyncio/staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def staggered_race(coro_fns, delay, *, loop=None):
Args:
coro_fns: an iterable of coroutine functions, i.e. callables that
return a coroutine object when called. Use ``functools.partial`` or
lambdas to pass arguments.
lambdas to pass arguments. Can also be an async iterable.

delay: amount of time, in seconds, between starting coroutines. If
``None``, the coroutines will run sequentially.
Expand All @@ -62,10 +62,19 @@ async def staggered_race(coro_fns, delay, *, loop=None):
coroutine's entry is ``None``.

"""
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
# Support for async iterables in coro_fns
try:
# Try to get an async iterator
aiter_coro_fns = aiter(coro_fns)
is_async_iterable = True
enum_coro_fns = None
except TypeError:
# Not an async iterable, use regular iteration
enum_coro_fns = enumerate(coro_fns)
is_async_iterable = False
aiter_coro_fns = None
loop = loop or events.get_running_loop()
parent_task = tasks.current_task(loop)
enum_coro_fns = enumerate(coro_fns)
winner_result = None
winner_index = None
unhandled_exceptions = []
Expand Down Expand Up @@ -106,8 +115,12 @@ async def run_one_coro(ok_to_start, previous_failed) -> None:
await tasks.wait_for(previous_failed.wait(), delay)
# Get the next coroutine to run
try:
this_index, coro_fn = next(enum_coro_fns)
except StopIteration:
if is_async_iterable:
coro_fn = await anext(aiter_coro_fns)
this_index = len(exceptions) # Track index manually for async iterables
else:
this_index, coro_fn = next(enum_coro_fns)
except (StopIteration, StopAsyncIteration):
return
# Start task that will run the next coroutine
this_failed = locks.Event()
Expand Down
197 changes: 197 additions & 0 deletions Lib/test/test_asyncio/test_staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,200 @@ async def coro_fn():
raise

self.assertListEqual(log, ["cancelled 1", "cancelled 2", "cancelled 3"])

async def test_async_iterable_empty(self):
async def empty_async_iterable():
if False:
yield lambda: asyncio.sleep(0)

winner, index, excs = await staggered_race(
empty_async_iterable(),
delay=None,
)

self.assertIs(winner, None)
self.assertIs(index, None)
self.assertEqual(excs, [])

async def test_async_iterable_one_successful(self):
async def async_coro_generator():
async def coro(index):
return f'Async Res: {index}'

yield lambda: coro(0)
yield lambda: coro(1)

winner, index, excs = await staggered_race(
async_coro_generator(),
delay=None,
)

self.assertEqual(winner, 'Async Res: 0')
self.assertEqual(index, 0)
self.assertEqual(excs, [None])

async def test_async_iterable_first_error_second_successful(self):
async def async_coro_generator():
async def coro(index):
if index == 0:
raise ValueError(f'Async Error: {index}')
return f'Async Res: {index}'

yield lambda: coro(0)
yield lambda: coro(1)

winner, index, excs = await staggered_race(
async_coro_generator(),
delay=None,
)

self.assertEqual(winner, 'Async Res: 1')
self.assertEqual(index, 1)
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], ValueError)
self.assertEqual(str(excs[0]), 'Async Error: 0')
self.assertIs(excs[1], None)

async def test_async_iterable_first_timeout_second_successful(self):
async def async_coro_generator():
async def coro(index):
if index == 0:
await asyncio.sleep(10)
return f'Async Res: {index}'

yield lambda: coro(0)
yield lambda: coro(1)

winner, index, excs = await staggered_race(
async_coro_generator(),
delay=0.1,
)

self.assertEqual(winner, 'Async Res: 1')
self.assertEqual(index, 1)
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], asyncio.CancelledError)
self.assertIs(excs[1], None)

async def test_async_iterable_none_successful(self):
async def async_coro_generator():
async def coro(index):
raise ValueError(f'Async Error: {index}')

yield lambda: coro(0)
yield lambda: coro(1)

winner, index, excs = await staggered_race(
async_coro_generator(),
delay=None,
)

self.assertIs(winner, None)
self.assertIs(index, None)
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], ValueError)
self.assertEqual(str(excs[0]), 'Async Error: 0')
self.assertIsInstance(excs[1], ValueError)
self.assertEqual(str(excs[1]), 'Async Error: 1')

async def test_async_iterable_multiple_winners(self):
event = asyncio.Event()

async def async_coro_generator():
async def coro(index):
await event.wait()
return f'Async Index: {index}'

async def do_set():
event.set()
await asyncio.Event().wait()

yield lambda: coro(0)
yield lambda: coro(1)
yield do_set

winner, index, excs = await staggered_race(
async_coro_generator(),
delay=0.1,
)

self.assertEqual(winner, 'Async Index: 0')
self.assertEqual(index, 0)
self.assertEqual(len(excs), 3)
self.assertIsNone(excs[0])
self.assertIsInstance(excs[1], asyncio.CancelledError)
self.assertIsInstance(excs[2], asyncio.CancelledError)

async def test_async_iterable_with_delay(self):
results = []

async def async_coro_generator():
async def coro(index):
results.append(f'Started: {index}')
await asyncio.sleep(0.05)
return f'Result: {index}'

yield lambda: coro(0)
yield lambda: coro(1)
yield lambda: coro(2)

winner, index, excs = await staggered_race(
async_coro_generator(),
delay=0.02,
)

self.assertEqual(winner, 'Result: 0')
self.assertEqual(index, 0)

self.assertGreaterEqual(len(excs), 1)
self.assertIsNone(excs[0])

self.assertIn('Started: 0', results)

async def test_async_iterable_mixed_with_regular(self):
async def coro(index):
return f'Mixed Res: {index}'

winner, index, excs = await staggered_race(
[lambda: coro(0), lambda: coro(1)],
delay=None,
)

self.assertEqual(winner, 'Mixed Res: 0')
self.assertEqual(index, 0)
self.assertEqual(excs, [None])

async def test_async_iterable_cancelled(self):
log = []

async def async_coro_generator():
async def coro_fn():
try:
await asyncio.sleep(0.1)
except asyncio.CancelledError:
log.append("async cancelled")
raise

yield coro_fn

with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01):
await staggered_race(async_coro_generator(), delay=None)

self.assertListEqual(log, ["async cancelled"])

async def test_async_iterable_stop_async_iteration(self):
async def async_coro_generator():
async def coro():
return "success"

yield lambda: coro()

winner, index, excs = await staggered_race(
async_coro_generator(),
delay=None,
)

self.assertEqual(winner, "success")
self.assertEqual(index, 0)
self.assertEqual(excs, [None])
Loading