Skip to content

Channels stubs with stubtest #13949

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,12 @@ This has the following keys:
If not specified, stubtest is run only on `linux`.
Only add extra OSes to the test
if there are platform-specific branches in a stubs package.
* `mypy_plugins` (default: `[]`): A list of Python modules to use as mypy plugins
when running stubtest. For example: `mypy_plugins = ["mypy_django_plugin.main"]`
* `mypy_plugins_config` (default: `{}`): A dictionary mapping plugin names to their
configuration dictionaries for use by mypy plugins. For example:
`mypy_plugins_config = {"django-stubs" = {"django_settings_module" = "@tests.django_settings"}}`


`*_dependencies` are usually packages needed to `pip install` the implementation
distribution.
Expand Down
16 changes: 15 additions & 1 deletion lib/ts_utils/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections.abc import Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated, Final, NamedTuple, final
from typing import Annotated, Any, Final, NamedTuple, final
from typing_extensions import TypeGuard

import tomli
Expand Down Expand Up @@ -42,6 +42,10 @@ def _is_list_of_strings(obj: object) -> TypeGuard[list[str]]:
return isinstance(obj, list) and all(isinstance(item, str) for item in obj)


def _is_nested_dict(obj: object) -> TypeGuard[dict[str, dict[str, Any]]]:
return isinstance(obj, dict) and all(isinstance(k, str) and isinstance(v, dict) for k, v in obj.items())


@functools.cache
def _get_oldest_supported_python() -> str:
with PYPROJECT_PATH.open("rb") as config:
Expand Down Expand Up @@ -71,6 +75,8 @@ class StubtestSettings:
ignore_missing_stub: bool
platforms: list[str]
stubtest_requirements: list[str]
mypy_plugins: list[str]
mypy_plugins_config: dict[str, dict[str, Any]]

def system_requirements_for_platform(self, platform: str) -> list[str]:
assert platform in _STUBTEST_PLATFORM_MAPPING, f"Unrecognised platform {platform!r}"
Expand All @@ -93,6 +99,8 @@ def read_stubtest_settings(distribution: str) -> StubtestSettings:
ignore_missing_stub: object = data.get("ignore_missing_stub", False)
specified_platforms: object = data.get("platforms", ["linux"])
stubtest_requirements: object = data.get("stubtest_requirements", [])
mypy_plugins: object = data.get("mypy_plugins", [])
mypy_plugins_config: object = data.get("mypy_plugins_config", {})

assert type(skip) is bool
assert type(ignore_missing_stub) is bool
Expand All @@ -104,6 +112,8 @@ def read_stubtest_settings(distribution: str) -> StubtestSettings:
assert _is_list_of_strings(choco_dependencies)
assert _is_list_of_strings(extras)
assert _is_list_of_strings(stubtest_requirements)
assert _is_list_of_strings(mypy_plugins)
assert _is_nested_dict(mypy_plugins_config)

unrecognised_platforms = set(specified_platforms) - _STUBTEST_PLATFORM_MAPPING.keys()
assert not unrecognised_platforms, f"Unrecognised platforms specified for {distribution!r}: {unrecognised_platforms}"
Expand All @@ -124,6 +134,8 @@ def read_stubtest_settings(distribution: str) -> StubtestSettings:
ignore_missing_stub=ignore_missing_stub,
platforms=specified_platforms,
stubtest_requirements=stubtest_requirements,
mypy_plugins=mypy_plugins,
mypy_plugins_config=mypy_plugins_config,
)


Expand Down Expand Up @@ -179,6 +191,8 @@ def is_obsolete(self) -> bool:
"ignore_missing_stub",
"platforms",
"stubtest_requirements",
"mypy_plugins",
"mypy_plugins_config",
}
}
_DIST_NAME_RE: Final = re.compile(r"^[a-z0-9]([a-z0-9._-]*[a-z0-9])?$", re.IGNORECASE)
Expand Down
17 changes: 15 additions & 2 deletions lib/ts_utils/mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import tomli

from ts_utils.metadata import metadata_path
from ts_utils.metadata import StubtestSettings, metadata_path
from ts_utils.utils import NamedTemporaryFile, TemporaryFileWrapper


Expand Down Expand Up @@ -50,14 +50,27 @@ def validate_configuration(section_name: str, mypy_section: dict[str, Any]) -> M


@contextmanager
def temporary_mypy_config_file(configurations: Iterable[MypyDistConf]) -> Generator[TemporaryFileWrapper[str]]:
def temporary_mypy_config_file(
configurations: Iterable[MypyDistConf], stubtest_settings: StubtestSettings | None = None
) -> Generator[TemporaryFileWrapper[str]]:
temp = NamedTemporaryFile("w+")
try:
for dist_conf in configurations:
temp.write(f"[mypy-{dist_conf.module_name}]\n")
for k, v in dist_conf.values.items():
temp.write(f"{k} = {v}\n")
temp.write("[mypy]\n")

if stubtest_settings:
if stubtest_settings.mypy_plugins:
temp.write(f"plugins = {'.'.join(stubtest_settings.mypy_plugins)}\n")

if stubtest_settings.mypy_plugins_config:
for plugin_name, plugin_dict in stubtest_settings.mypy_plugins_config.items():
temp.write(f"[mypy.plugins.{plugin_name}]\n")
for k, v in plugin_dict.items():
temp.write(f"{k} = {v}\n")

temp.flush()
yield temp
finally:
Expand Down
1 change: 1 addition & 0 deletions pyrightconfig.stricter.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"stubs/braintree",
"stubs/caldav",
"stubs/cffi",
"stubs/channels",
"stubs/click-web",
"stubs/corus",
"stubs/dateparser",
Expand Down
12 changes: 12 additions & 0 deletions stubs/channels/@tests/django_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SECRET_KEY = "1"

INSTALLED_APPS = (
"django.contrib.contenttypes",
"django.contrib.sites",
"django.contrib.sessions",
"django.contrib.messages",
"django.contrib.admin.apps.SimpleAdminConfig",
"django.contrib.staticfiles",
"django.contrib.auth",
"channels",
)
3 changes: 3 additions & 0 deletions stubs/channels/@tests/stubtest_allowlist.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
channels.auth.UserLazyObject
channels.auth.UserLazyObject.*
channels.db.database_sync_to_async
8 changes: 8 additions & 0 deletions stubs/channels/METADATA.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
version = "4.*"
upstream_repository = "https://github.com/django/channels"
requires = ["django-stubs>=4.2,<5.3", "asgiref"]

[tool.stubtest]
mypy_plugins = ['mypy_django_plugin.main']
mypy_plugins_config = {"django-stubs" = {"django_settings_module" = "@tests.django_settings"}}
stubtest_requirements = ["daphne"]
2 changes: 2 additions & 0 deletions stubs/channels/channels/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__version__: str
DEFAULT_CHANNEL_LAYER: str
5 changes: 5 additions & 0 deletions stubs/channels/channels/apps.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django.apps import AppConfig

class ChannelsConfig(AppConfig):
name: str = ...
verbose_name: str = ...
32 changes: 32 additions & 0 deletions stubs/channels/channels/auth.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Any

from asgiref.typing import ASGIReceiveCallable, ASGISendCallable
from channels.middleware import BaseMiddleware
from django.contrib.auth.backends import BaseBackend
from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.models import AnonymousUser
from django.utils.functional import LazyObject

from .consumer import _ChannelScope, _LazySession
from .db import database_sync_to_async
from .utils import _ChannelApplication

@database_sync_to_async
def get_user(scope: _ChannelScope) -> AbstractBaseUser | AnonymousUser: ...
@database_sync_to_async
def login(scope: _ChannelScope, user: AbstractBaseUser, backend: BaseBackend | None = ...) -> None: ...
@database_sync_to_async
def logout(scope: _ChannelScope) -> None: ...
def _get_user_session_key(session: _LazySession) -> Any: ...

class UserLazyObject(AbstractBaseUser, LazyObject):
def _setup(self) -> None: ...

class AuthMiddleware(BaseMiddleware):
def populate_scope(self, scope: _ChannelScope) -> None: ...
async def resolve_scope(self, scope: _ChannelScope) -> None: ...
async def __call__(
self, scope: _ChannelScope, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> _ChannelApplication: ...

def AuthMiddlewareStack(inner: _ChannelApplication) -> _ChannelApplication: ...
57 changes: 57 additions & 0 deletions stubs/channels/channels/consumer.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from collections.abc import Awaitable
from typing import Any, ClassVar, Protocol, type_check_only

from asgiref.typing import ASGIReceiveCallable, ASGISendCallable, Scope, WebSocketScope
from channels.auth import UserLazyObject
from channels.db import database_sync_to_async
from channels.layers import BaseChannelLayer
from django.contrib.sessions.backends.base import SessionBase
from django.utils.functional import LazyObject

@type_check_only
class _LazySession(SessionBase, LazyObject): # type: ignore[misc]
_wrapped: SessionBase

# Base ASGI Scope definition
@type_check_only
class _ChannelScope(WebSocketScope, total=False):
# Channel specific
channel: str
url_route: dict[str, Any]
path_remaining: str

# Auth specific
cookies: dict[str, str]
session: _LazySession
user: UserLazyObject | None

def get_handler_name(message: dict[str, Any]) -> str: ...
@type_check_only
class _ASGIApplicationProtocol(Protocol):
consumer_class: Any
consumer_initkwargs: dict[str, Any]

def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> Awaitable[None]: ...

class AsyncConsumer:
_sync: ClassVar[bool] = ...
channel_layer_alias: ClassVar[str] = ...

scope: _ChannelScope
channel_layer: BaseChannelLayer
channel_name: str
channel_receive: ASGIReceiveCallable
base_send: ASGISendCallable

async def __call__(self, scope: _ChannelScope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ...
async def dispatch(self, message: dict[str, Any]) -> None: ...
async def send(self, message: dict[str, Any]) -> None: ...
@classmethod
def as_asgi(cls, **initkwargs: Any) -> _ASGIApplicationProtocol: ...

class SyncConsumer(AsyncConsumer):
_sync: ClassVar[bool] = ...

@database_sync_to_async
def dispatch(self, message: dict[str, Any]) -> None: ... # type: ignore[override]
def send(self, message: dict[str, Any]) -> None: ... # type: ignore[override]
15 changes: 15 additions & 0 deletions stubs/channels/channels/db.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from asyncio import BaseEventLoop
from collections.abc import Callable, Coroutine
from typing import Any, TypeVar
from typing_extensions import ParamSpec

from asgiref.sync import SyncToAsync

_P = ParamSpec("_P")
_R = TypeVar("_R")

class DatabaseSyncToAsync(SyncToAsync[_P, _R]):
def thread_handler(self, loop: BaseEventLoop, *args: Any, **kwargs: Any) -> Any: ...

def database_sync_to_async(func: Callable[_P, _R]) -> Callable[_P, Coroutine[Any, Any, _R]]: ...
async def aclose_old_connections() -> None: ...
8 changes: 8 additions & 0 deletions stubs/channels/channels/exceptions.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class RequestAborted(Exception): ...
class RequestTimeout(RequestAborted): ...
class InvalidChannelLayerError(ValueError): ...
class AcceptConnection(Exception): ...
class DenyConnection(Exception): ...
class ChannelFull(Exception): ...
class MessageTooLarge(Exception): ...
class StopConsumer(Exception): ...
Empty file.
19 changes: 19 additions & 0 deletions stubs/channels/channels/generic/http.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from collections.abc import Iterable
from typing import Any

from asgiref.typing import HTTPDisconnectEvent, HTTPRequestEvent, HTTPScope
from channels.consumer import AsyncConsumer

class AsyncHttpConsumer(AsyncConsumer):
body: list[bytes]
scope: HTTPScope # type: ignore[assignment]

def __init__(self, *args: Any, **kwargs: Any) -> None: ...
async def send_headers(self, *, status: int = ..., headers: Iterable[tuple[bytes, bytes]] | None = ...) -> None: ...
async def send_body(self, body: bytes, *, more_body: bool = ...) -> None: ...
async def send_response(self, status: int, body: bytes, **kwargs: Any) -> None: ...
async def handle(self, body: bytes) -> None: ...
async def disconnect(self) -> None: ...
async def http_request(self, message: HTTPRequestEvent) -> None: ...
async def http_disconnect(self, message: HTTPDisconnectEvent) -> None: ...
async def send(self, message: dict[str, Any]) -> None: ... # type: ignore[override]
65 changes: 65 additions & 0 deletions stubs/channels/channels/generic/websocket.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import Any

from asgiref.typing import WebSocketConnectEvent, WebSocketDisconnectEvent, WebSocketReceiveEvent
from channels.consumer import AsyncConsumer, SyncConsumer, _ChannelScope
from channels.layers import BaseChannelLayer

class WebsocketConsumer(SyncConsumer):
groups: list[str] | None
scope: _ChannelScope
channel_name: str
channel_layer: BaseChannelLayer
channel_receive: Any
base_send: Any

def __init__(self, *args: Any, **kwargs: Any) -> None: ...
def websocket_connect(self, message: WebSocketConnectEvent) -> None: ...
def connect(self) -> None: ...
def accept(self, subprotocol: str | None = ..., headers: list[tuple[str, str]] | None = ...) -> None: ...
def websocket_receive(self, message: WebSocketReceiveEvent) -> None: ...
def receive(self, text_data: str | None = ..., bytes_data: bytes | None = ...) -> None: ...
def send( # type: ignore[override]
self, text_data: str | None = ..., bytes_data: bytes | None = ..., close: bool = ...
) -> None: ...
def close(self, code: int | bool | None = ..., reason: str | None = ...) -> None: ...
def websocket_disconnect(self, message: WebSocketDisconnectEvent) -> None: ...
def disconnect(self, code: int) -> None: ...

class JsonWebsocketConsumer(WebsocketConsumer):
def receive(self, text_data: str | None = ..., bytes_data: bytes | None = ..., **kwargs: Any) -> None: ...
def receive_json(self, content: Any, **kwargs: Any) -> None: ...
def send_json(self, content: Any, close: bool = ...) -> None: ...
@classmethod
def decode_json(cls, text_data: str) -> Any: ...
@classmethod
def encode_json(cls, content: Any) -> str: ...

class AsyncWebsocketConsumer(AsyncConsumer):
groups: list[str] | None
scope: _ChannelScope
channel_name: str
channel_layer: BaseChannelLayer
channel_receive: Any
base_send: Any

def __init__(self, *args: Any, **kwargs: Any) -> None: ...
async def websocket_connect(self, message: WebSocketConnectEvent) -> None: ...
async def connect(self) -> None: ...
async def accept(self, subprotocol: str | None = ..., headers: list[tuple[str, str]] | None = ...) -> None: ...
async def websocket_receive(self, message: WebSocketReceiveEvent) -> None: ...
async def receive(self, text_data: str | None = ..., bytes_data: bytes | None = ...) -> None: ...
async def send( # type: ignore[override]
self, text_data: str | None = ..., bytes_data: bytes | None = ..., close: bool = ...
) -> None: ...
async def close(self, code: int | bool | None = ..., reason: str | None = ...) -> None: ...
async def websocket_disconnect(self, message: WebSocketDisconnectEvent) -> None: ...
async def disconnect(self, code: int) -> None: ...

class AsyncJsonWebsocketConsumer(AsyncWebsocketConsumer):
async def receive(self, text_data: str | None = ..., bytes_data: bytes | None = ..., **kwargs: Any) -> None: ...
async def receive_json(self, content: Any, **kwargs: Any) -> None: ...
async def send_json(self, content: Any, close: bool = ...) -> None: ...
@classmethod
async def decode_json(cls, text_data: str) -> Any: ...
@classmethod
async def encode_json(cls, content: Any) -> str: ...
Loading