Skip to content

Commit

Permalink
get demo running
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed Jan 3, 2025
1 parent 404c93d commit 14e0a83
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 22 deletions.
13 changes: 9 additions & 4 deletions taskgroup/install.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys
import contextvars
import asyncio
import collections.abc
import contextlib
import types
from typing import cast
Expand All @@ -9,6 +9,11 @@

from typing_extensions import Self, TypeVar

if sys.version_info >= (3, 9):
from collections.abc import Generator, Coroutine
else:
from typing import Generator, Coroutine


UNCANCEL_DONE = object()

Expand Down Expand Up @@ -49,12 +54,12 @@ def _async_yield(v):


class WrapCoro(
collections.abc.Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
):
def __init__(
self,
coro: collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
coro: Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
context: contextvars.Context,
):
self._coro = coro
Expand Down
2 changes: 1 addition & 1 deletion taskgroup/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def close(self) -> None:
loop.run_until_complete(
loop.shutdown_default_executor(constants.THREAD_JOIN_TIMEOUT) # type: ignore
)
else:
elif sys.version_info >= (3, 9):
loop.run_until_complete(loop.shutdown_default_executor())
finally:
if self._set_event_loop:
Expand Down
14 changes: 9 additions & 5 deletions taskgroup/taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

__all__ = ["TaskGroup"]
import sys
import collections.abc
from types import TracebackType
from asyncio import events
from asyncio import exceptions
Expand All @@ -24,6 +23,11 @@
from typing_extensions import Self, TypeAlias, Literal, TypeVar
import contextlib

if sys.version_info >= (3, 9):
from collections.abc import Generator, Coroutine
else:
from typing import Generator, Coroutine


_T = TypeVar("_T")

Expand All @@ -35,14 +39,14 @@

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_TaskYieldType: TypeAlias = Optional[futures.Future[object]]
_TaskYieldType: TypeAlias = "futures.Future[object] | None"

if sys.version_info >= (3, 12):
_TaskCompatibleCoro: TypeAlias = collections.abc.Coroutine[Any, Any, _T_co]
_TaskCompatibleCoro: TypeAlias = Coroutine[Any, Any, _T_co]
else:
_TaskCompatibleCoro: TypeAlias = Union[
collections.abc.Generator[_TaskYieldType, None, _T_co],
collections.abc.Coroutine[Any, Any, _T_co],
Generator[_TaskYieldType, None, _T_co],
Coroutine[Any, Any, _T_co],
]


Expand Down
46 changes: 34 additions & 12 deletions taskgroup/tasks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

import asyncio
import collections.abc
import contextvars
from typing import Any, Optional, Union
from typing import Any, Union, TYPE_CHECKING, Generic
from typing_extensions import TypeAlias, TypeVar, Self
import sys

if sys.version_info >= (3, 9):
from collections.abc import Generator, Coroutine, Awaitable
else:
from typing import Generator, Coroutine, Awaitable

_YieldT_co = TypeVar("_YieldT_co", covariant=True)
_SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=None)
_ReturnT_co = TypeVar("_ReturnT_co", covariant=True, default=None)
Expand All @@ -15,26 +19,30 @@

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_TaskYieldType: TypeAlias = Optional[asyncio.Future[object]]
_TaskYieldType: TypeAlias = "asyncio.Future[object] | None"

if sys.version_info >= (3, 12):
_TaskCompatibleCoro: TypeAlias = collections.abc.Coroutine[Any, Any, _T_co]
else:
_TaskCompatibleCoro: TypeAlias = Coroutine[Any, Any, _T_co]

Check failure on line 25 in taskgroup/tasks.py

View workflow job for this annotation

GitHub Actions / typecheck

Declaration "_TaskCompatibleCoro" is obscured by a declaration of the same name (reportRedeclaration)
if sys.version_info >= (3, 9):
_TaskCompatibleCoro: TypeAlias = Union[

Check failure on line 27 in taskgroup/tasks.py

View workflow job for this annotation

GitHub Actions / typecheck

"_TaskCompatibleCoro" is declared as a TypeAlias and can be assigned only once (reportRedeclaration)
collections.abc.Generator[_TaskYieldType, None, _T_co],
collections.abc.Coroutine[Any, Any, _T_co],
Generator[_TaskYieldType, None, _T_co],
Coroutine[Any, Any, _T_co],
]
else:
_TaskCompatibleCoro: TypeAlias = (
"Generator[_TaskYieldType, None, _T_co] | Awaitable[_T_co]"
)


class _Interceptor(
collections.abc.Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
):
def __init__(
self,
coro: (
collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]
| collections.abc.Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]
Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]
| Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]
),
context: contextvars.Context,
):
Expand All @@ -57,7 +65,21 @@ def close(self) -> None:
super().close()


class Task(asyncio.Task[_T_co]):
if TYPE_CHECKING:

class _Task(asyncio.Task[_T_co]):

Check failure on line 70 in taskgroup/tasks.py

View workflow job for this annotation

GitHub Actions / typecheck

Class declaration "_Task" is obscured by a declaration of the same name (reportRedeclaration)
pass


if sys.version_info >= (3, 8):

class _Task(asyncio.Task, Generic[_T_co]):
pass
else:
_Task = asyncio.Task


class Task(_Task[_T_co]):
def __init__(
self, coro: _TaskCompatibleCoro[_T_co], *args, context=None, **kwargs
) -> None:
Expand Down

0 comments on commit 14e0a83

Please sign in to comment.