diff --git a/stdlib/@tests/test_cases/check_contextlib.py b/stdlib/@tests/test_cases/check_contextlib.py index 648661bca856..63f8b27e8454 100644 --- a/stdlib/@tests/test_cases/check_contextlib.py +++ b/stdlib/@tests/test_cases/check_contextlib.py @@ -1,9 +1,18 @@ from __future__ import annotations -from contextlib import ExitStack +from contextlib import AbstractContextManager, ExitStack from typing_extensions import assert_type +class CM1(AbstractContextManager): + def __exit__(self, *args) -> None: + return None + + +with CM1() as cm1: + assert_type(cm1, CM1) + + # See issue #7961 class Thing(ExitStack): pass diff --git a/stdlib/contextlib.pyi b/stdlib/contextlib.pyi index 4663b448c79c..d6d993265fce 100644 --- a/stdlib/contextlib.pyi +++ b/stdlib/contextlib.pyi @@ -31,6 +31,7 @@ if sys.version_info >= (3, 11): _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) _T_io = TypeVar("_T_io", bound=IO[str] | None) +_EnterT_co = TypeVar("_EnterT_co", covariant=True, default=Self) _ExitT_co = TypeVar("_ExitT_co", covariant=True, bound=bool | None, default=bool | None) _F = TypeVar("_F", bound=Callable[..., Any]) _G_co = TypeVar("_G_co", bound=Generator[Any, Any, Any] | AsyncGenerator[Any, Any], covariant=True) @@ -46,8 +47,8 @@ _CM_EF = TypeVar("_CM_EF", bound=AbstractContextManager[Any, Any] | _ExitFunc) # At runtime it inherits from ABC and is not a Protocol, but it is on the # allowlist for use as a Protocol. @runtime_checkable -class AbstractContextManager(ABC, Protocol[_T_co, _ExitT_co]): # type: ignore[misc] # pyright: ignore[reportGeneralTypeIssues] - def __enter__(self) -> _T_co: ... +class AbstractContextManager(ABC, Protocol[_EnterT_co, _ExitT_co]): # type: ignore[misc] # pyright: ignore[reportGeneralTypeIssues] + def __enter__(self) -> _EnterT_co: ... @abstractmethod def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, / @@ -57,8 +58,8 @@ class AbstractContextManager(ABC, Protocol[_T_co, _ExitT_co]): # type: ignore[m # At runtime it inherits from ABC and is not a Protocol, but it is on the # allowlist for use as a Protocol. @runtime_checkable -class AbstractAsyncContextManager(ABC, Protocol[_T_co, _ExitT_co]): # type: ignore[misc] # pyright: ignore[reportGeneralTypeIssues] - async def __aenter__(self) -> _T_co: ... +class AbstractAsyncContextManager(ABC, Protocol[_EnterT_co, _ExitT_co]): # type: ignore[misc] # pyright: ignore[reportGeneralTypeIssues] + async def __aenter__(self) -> _EnterT_co: ... @abstractmethod async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, /