Skip to content

Commit

Permalink
gh-128308: pass **kwargs to asyncio task_factory (#128768)
Browse files Browse the repository at this point in the history
Co-authored-by: Kumar Aditya <[email protected]>
  • Loading branch information
graingert and kumaraditya303 authored Jan 20, 2025
1 parent 6c914bf commit 38a9956
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 29 deletions.
4 changes: 2 additions & 2 deletions Doc/library/asyncio-eventloop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,9 @@ Creating Futures and Tasks

If *factory* is ``None`` the default task factory will be set.
Otherwise, *factory* must be a *callable* with the signature matching
``(loop, coro, context=None)``, where *loop* is a reference to the active
``(loop, coro, **kwargs)``, where *loop* is a reference to the active
event loop, and *coro* is a coroutine object. The callable
must return a :class:`asyncio.Future`-compatible object.
must pass on all *kwargs*, and return a :class:`asyncio.Task`-compatible object.

.. method:: loop.get_task_factory()

Expand Down
26 changes: 10 additions & 16 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,25 +458,18 @@ def create_future(self):
"""Create a Future object attached to the loop."""
return futures.Future(loop=self)

def create_task(self, coro, *, name=None, context=None):
def create_task(self, coro, **kwargs):
"""Schedule a coroutine object.
Return a task object.
"""
self._check_closed()
if self._task_factory is None:
task = tasks.Task(coro, loop=self, name=name, context=context)
if task._source_traceback:
del task._source_traceback[-1]
else:
if context is None:
# Use legacy API if context is not needed
task = self._task_factory(self, coro)
else:
task = self._task_factory(self, coro, context=context)

task.set_name(name)
if self._task_factory is not None:
return self._task_factory(self, coro, **kwargs)

task = tasks.Task(coro, loop=self, **kwargs)
if task._source_traceback:
del task._source_traceback[-1]
try:
return task
finally:
Expand All @@ -490,9 +483,10 @@ def set_task_factory(self, factory):
If factory is None the default task factory will be set.
If factory is a callable, it should have a signature matching
'(loop, coro)', where 'loop' will be a reference to the active
event loop, 'coro' will be a coroutine object. The callable
must return a Future.
'(loop, coro, **kwargs)', where 'loop' will be a reference to the active
event loop, 'coro' will be a coroutine object, and **kwargs will be
arbitrary keyword arguments that should be passed on to Task.
The callable must return a Task.
"""
if factory is not None and not callable(factory):
raise TypeError('task factory must be a callable or None')
Expand Down
2 changes: 1 addition & 1 deletion Lib/asyncio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def create_future(self):

# Method scheduling a coroutine object: create a task.

def create_task(self, coro, *, name=None, context=None):
def create_task(self, coro, **kwargs):
raise NotImplementedError

# Methods for interacting with threads.
Expand Down
4 changes: 2 additions & 2 deletions Lib/test/test_asyncio/test_base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,8 +833,8 @@ async def test():
loop.close()

def test_create_named_task_with_custom_factory(self):
def task_factory(loop, coro):
return asyncio.Task(coro, loop=loop)
def task_factory(loop, coro, **kwargs):
return asyncio.Task(coro, loop=loop, **kwargs)

async def test():
pass
Expand Down
12 changes: 12 additions & 0 deletions Lib/test/test_asyncio/test_eager_task_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,18 @@ async def run():

self.run_coro(run())

def test_name(self):
name = None
async def coro():
nonlocal name
name = asyncio.current_task().get_name()

async def main():
task = self.loop.create_task(coro(), name="test name")
self.assertEqual(name, "test name")
await task

self.run_coro(coro())

class AsyncTaskCounter:
def __init__(self, loop, *, task_class, eager):
Expand Down
16 changes: 8 additions & 8 deletions Lib/test/test_asyncio/test_free_threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,25 +112,25 @@ class TestPyFreeThreading(TestFreeThreading, TestCase):
all_tasks = staticmethod(asyncio.tasks._py_all_tasks)
current_task = staticmethod(asyncio.tasks._py_current_task)

def factory(self, loop, coro, context=None):
return asyncio.tasks._PyTask(coro, loop=loop, context=context)
def factory(self, loop, coro, **kwargs):
return asyncio.tasks._PyTask(coro, loop=loop, **kwargs)


@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
class TestCFreeThreading(TestFreeThreading, TestCase):
all_tasks = staticmethod(getattr(asyncio.tasks, "_c_all_tasks", None))
current_task = staticmethod(getattr(asyncio.tasks, "_c_current_task", None))

def factory(self, loop, coro, context=None):
return asyncio.tasks._CTask(coro, loop=loop, context=context)
def factory(self, loop, coro, **kwargs):
return asyncio.tasks._CTask(coro, loop=loop, **kwargs)


class TestEagerPyFreeThreading(TestPyFreeThreading):
def factory(self, loop, coro, context=None):
return asyncio.tasks._PyTask(coro, loop=loop, context=context, eager_start=True)
def factory(self, loop, coro, eager_start=True, **kwargs):
return asyncio.tasks._PyTask(coro, loop=loop, **kwargs, eager_start=eager_start)


@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
class TestEagerCFreeThreading(TestCFreeThreading, TestCase):
def factory(self, loop, coro, context=None):
return asyncio.tasks._CTask(coro, loop=loop, context=context, eager_start=True)
def factory(self, loop, coro, eager_start=True, **kwargs):
return asyncio.tasks._CTask(coro, loop=loop, **kwargs, eager_start=eager_start)
12 changes: 12 additions & 0 deletions Lib/test/test_asyncio/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,18 @@ class MyKeyboardInterrupt(KeyboardInterrupt):
self.assertIsNotNone(exc)
self.assertListEqual(gc.get_referrers(exc), no_other_refs())

async def test_name(self):
name = None

async def asyncfn():
nonlocal name
name = asyncio.current_task().get_name()

async with asyncio.TaskGroup() as tg:
tg.create_task(asyncfn(), name="example name")

self.assertEqual(name, "example name")


class TestTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase):
loop_factory = asyncio.EventLoop
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support the *name* keyword argument for eager tasks in :func:`asyncio.loop.create_task`, :func:`asyncio.create_task` and :func:`asyncio.TaskGroup.create_task`, by passing on all *kwargs* to the task factory set by :func:`asyncio.loop.set_task_factory`.

0 comments on commit 38a9956

Please sign in to comment.