diff --git a/pyrightconfig.stricter.json b/pyrightconfig.stricter.json index 9acd48158ca0..f08f827fff3a 100644 --- a/pyrightconfig.stricter.json +++ b/pyrightconfig.stricter.json @@ -31,6 +31,7 @@ "stubs/boltons", "stubs/braintree", "stubs/cffi", + "stubs/channels", "stubs/dateparser", "stubs/defusedxml", "stubs/docker", diff --git a/stubs/channels/@tests/django_settings.py b/stubs/channels/@tests/django_settings.py new file mode 100644 index 000000000000..2be16834be19 --- /dev/null +++ b/stubs/channels/@tests/django_settings.py @@ -0,0 +1,12 @@ +SECRET_KEY = "1" + +INSTALLED_APPS = ( + "django.contrib.contenttypes", + "django.contrib.sites", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.admin.apps.SimpleAdminConfig", + "django.contrib.staticfiles", + "django.contrib.auth", + "channels", +) diff --git a/stubs/channels/@tests/stubtest_allowlist.txt b/stubs/channels/@tests/stubtest_allowlist.txt new file mode 100644 index 000000000000..c5c5761b73fe --- /dev/null +++ b/stubs/channels/@tests/stubtest_allowlist.txt @@ -0,0 +1,15 @@ +# channels.auth.UserLazyObject metaclass is mismatch +channels.auth.UserLazyObject + +# these one need to be exclude due to mypy error: * is not present at runtime +channels.auth.UserLazyObject.DoesNotExist +channels.auth.UserLazyObject.MultipleObjectsReturned +channels.auth.UserLazyObject@AnnotatedWith + +# database_sync_to_async is implemented as a class instance but stubbed as a function +# for better type inference when used as decorator/function +channels.db.database_sync_to_async + +# Set to None on class, but initialized to non-None value in __init__ +channels.generic.websocket.WebsocketConsumer.groups +channels.generic.websocket.AsyncWebsocketConsumer.groups diff --git a/stubs/channels/METADATA.toml b/stubs/channels/METADATA.toml new file mode 100644 index 000000000000..a5df088dbbeb --- /dev/null +++ b/stubs/channels/METADATA.toml @@ -0,0 +1,8 @@ +version = "4.2.*" +upstream_repository = "https://github.com/django/channels" +requires = ["django-stubs>=4.2,<5.3", "asgiref"] + +[tool.stubtest] +mypy_plugins = ['mypy_django_plugin.main'] +mypy_plugins_config = {"django-stubs" = {"django_settings_module" = "@tests.django_settings"}} +stubtest_requirements = ["daphne"] diff --git a/stubs/channels/channels/__init__.pyi b/stubs/channels/channels/__init__.pyi new file mode 100644 index 000000000000..561199954632 --- /dev/null +++ b/stubs/channels/channels/__init__.pyi @@ -0,0 +1,4 @@ +from typing import Final + +__version__: Final[str] +DEFAULT_CHANNEL_LAYER: Final[str] diff --git a/stubs/channels/channels/apps.pyi b/stubs/channels/channels/apps.pyi new file mode 100644 index 000000000000..ad15a21b6961 --- /dev/null +++ b/stubs/channels/channels/apps.pyi @@ -0,0 +1,7 @@ +from typing import Final + +from django.apps import AppConfig + +class ChannelsConfig(AppConfig): + name: Final = "channels" + verbose_name: str = "Channels" diff --git a/stubs/channels/channels/auth.pyi b/stubs/channels/channels/auth.pyi new file mode 100644 index 000000000000..b8594667e059 --- /dev/null +++ b/stubs/channels/channels/auth.pyi @@ -0,0 +1,26 @@ +from asgiref.typing import ASGIReceiveCallable, ASGISendCallable +from channels.middleware import BaseMiddleware +from django.contrib.auth.backends import BaseBackend +from django.contrib.auth.base_user import AbstractBaseUser +from django.contrib.auth.models import AnonymousUser +from django.utils.functional import LazyObject + +from .consumer import _ChannelScope +from .utils import _ChannelApplication + +async def get_user(scope: _ChannelScope) -> AbstractBaseUser | AnonymousUser: ... +async def login(scope: _ChannelScope, user: AbstractBaseUser, backend: BaseBackend | None = None) -> None: ... +async def logout(scope: _ChannelScope) -> None: ... + +# Inherits AbstractBaseUser to improve autocomplete and show this is a lazy proxy for a user. +# At runtime, it's just a LazyObject that wraps the actual user instance. +class UserLazyObject(AbstractBaseUser, LazyObject): ... + +class AuthMiddleware(BaseMiddleware): + def populate_scope(self, scope: _ChannelScope) -> None: ... + async def resolve_scope(self, scope: _ChannelScope) -> None: ... + async def __call__( + self, scope: _ChannelScope, receive: ASGIReceiveCallable, send: ASGISendCallable + ) -> _ChannelApplication: ... + +def AuthMiddlewareStack(inner: _ChannelApplication) -> _ChannelApplication: ... diff --git a/stubs/channels/channels/consumer.pyi b/stubs/channels/channels/consumer.pyi new file mode 100644 index 000000000000..66a1ba0df89d --- /dev/null +++ b/stubs/channels/channels/consumer.pyi @@ -0,0 +1,75 @@ +from collections.abc import Awaitable +from typing import Any, ClassVar, Protocol, TypedDict, type_check_only + +from asgiref.typing import ASGIReceiveCallable, ASGISendCallable, Scope, WebSocketScope +from channels.auth import UserLazyObject +from channels.db import database_sync_to_async +from channels.layers import BaseChannelLayer +from django.contrib.sessions.backends.base import SessionBase +from django.utils.functional import LazyObject + +# _LazySession is a LazyObject that wraps a SessionBase instance. +# We subclass both for type checking purposes to expose SessionBase attributes, +# and suppress mypy's "misc" error with `# type: ignore[misc]`. +@type_check_only +class _LazySession(SessionBase, LazyObject): # type: ignore[misc] + _wrapped: SessionBase + +@type_check_only +class _URLRoute(TypedDict): + # Values extracted from Django's URLPattern matching, + # passed through ASGI scope routing. + # `args` and `kwargs` are the result of pattern matching against the URL path. + args: tuple[Any, ...] + kwargs: dict[str, Any] + +# Channel Scope definition +@type_check_only +class _ChannelScope(WebSocketScope, total=False): + # Channels specific + channel: str + url_route: _URLRoute + path_remaining: str + + # Auth specific + cookies: dict[str, str] + session: _LazySession + user: UserLazyObject | None + +# Accepts any ASGI message dict with a required "type" key (str), +# but allows additional arbitrary keys for flexibility. +def get_handler_name(message: dict[str, Any]) -> str: ... +@type_check_only +class _ASGIApplicationProtocol(Protocol): + consumer_class: AsyncConsumer + + # Accepts any initialization kwargs passed to the consumer class. + # Typed as `Any` to allow flexibility in subclass-specific arguments. + consumer_initkwargs: Any + + def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> Awaitable[None]: ... + +class AsyncConsumer: + channel_layer_alias: ClassVar[str] + + scope: _ChannelScope + channel_layer: BaseChannelLayer + channel_name: str + channel_receive: ASGIReceiveCallable + base_send: ASGISendCallable + + async def __call__(self, scope: _ChannelScope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ... + async def dispatch(self, message: dict[str, Any]) -> None: ... + async def send(self, message: dict[str, Any]) -> None: ... + + # initkwargs will be used to instantiate the consumer instance. + @classmethod + def as_asgi(cls, **initkwargs: Any) -> _ASGIApplicationProtocol: ... + +class SyncConsumer(AsyncConsumer): + + # Since we're overriding asynchronous methods with synchronous ones, + # we need to use `# type: ignore[override]` to suppress mypy errors. + @database_sync_to_async + def dispatch(self, message: dict[str, Any]) -> None: ... # type: ignore[override] + def send(self, message: dict[str, Any]) -> None: ... # type: ignore[override] diff --git a/stubs/channels/channels/db.pyi b/stubs/channels/channels/db.pyi new file mode 100644 index 000000000000..c0147a05b03d --- /dev/null +++ b/stubs/channels/channels/db.pyi @@ -0,0 +1,32 @@ +import asyncio +from _typeshed import OptExcInfo +from asyncio import BaseEventLoop +from collections.abc import Callable, Coroutine +from concurrent.futures import ThreadPoolExecutor +from typing import Any, TypeVar +from typing_extensions import ParamSpec + +from asgiref.sync import SyncToAsync + +_P = ParamSpec("_P") +_R = TypeVar("_R") + +class DatabaseSyncToAsync(SyncToAsync[_P, _R]): + def thread_handler( + self, + loop: BaseEventLoop, + exc_info: OptExcInfo, + task_context: list[asyncio.Task[Any]] | None, + func: Callable[_P, _R], + *args: _P.args, + **kwargs: _P.kwargs, + ) -> _R: ... + +# We define `database_sync_to_async` as a function instead of assigning +# `DatabaseSyncToAsync(...)` directly, to preserve both decorator and +# higher-order function behavior with correct type hints. +# A direct assignment would result in incorrect type inference for the wrapped function. +def database_sync_to_async( + func: Callable[_P, _R], thread_sensitive: bool = True, executor: ThreadPoolExecutor | None = None +) -> Callable[_P, Coroutine[Any, Any, _R]]: ... +async def aclose_old_connections() -> None: ... diff --git a/stubs/channels/channels/exceptions.pyi b/stubs/channels/channels/exceptions.pyi new file mode 100644 index 000000000000..eaba1dfaee14 --- /dev/null +++ b/stubs/channels/channels/exceptions.pyi @@ -0,0 +1,8 @@ +class RequestAborted(Exception): ... +class RequestTimeout(RequestAborted): ... +class InvalidChannelLayerError(ValueError): ... +class AcceptConnection(Exception): ... +class DenyConnection(Exception): ... +class ChannelFull(Exception): ... +class MessageTooLarge(Exception): ... +class StopConsumer(Exception): ... diff --git a/stubs/channels/channels/generic/__init__.pyi b/stubs/channels/channels/generic/__init__.pyi new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/stubs/channels/channels/generic/http.pyi b/stubs/channels/channels/generic/http.pyi new file mode 100644 index 000000000000..84e3b1dbff84 --- /dev/null +++ b/stubs/channels/channels/generic/http.pyi @@ -0,0 +1,19 @@ +from _typeshed import Unused +from collections.abc import Iterable +from typing import Any + +from asgiref.typing import HTTPDisconnectEvent, HTTPRequestEvent, HTTPScope +from channels.consumer import AsyncConsumer + +class AsyncHttpConsumer(AsyncConsumer): + body: list[bytes] + scope: HTTPScope # type: ignore[assignment] + + def __init__(self, *args: Unused, **kwargs: Unused) -> None: ... + async def send_headers(self, *, status: int = 200, headers: Iterable[tuple[bytes, bytes]] | None = None) -> None: ... + async def send_body(self, body: bytes, *, more_body: bool = False) -> None: ... + async def send_response(self, status: int, body: bytes, **kwargs: Any) -> None: ... + async def handle(self, body: bytes) -> None: ... + async def disconnect(self) -> None: ... + async def http_request(self, message: HTTPRequestEvent) -> None: ... + async def http_disconnect(self, message: HTTPDisconnectEvent) -> None: ... diff --git a/stubs/channels/channels/generic/websocket.pyi b/stubs/channels/channels/generic/websocket.pyi new file mode 100644 index 000000000000..f19d406e1e5d --- /dev/null +++ b/stubs/channels/channels/generic/websocket.pyi @@ -0,0 +1,61 @@ +from _typeshed import Unused +from typing import Any + +from asgiref.typing import WebSocketConnectEvent, WebSocketDisconnectEvent, WebSocketReceiveEvent +from channels.consumer import AsyncConsumer, SyncConsumer + +class WebsocketConsumer(SyncConsumer): + groups: list[str] + + def __init__(self, *args: Unused, **kwargs: Unused) -> None: ... + def websocket_connect(self, message: WebSocketConnectEvent) -> None: ... + def connect(self) -> None: ... + def accept(self, subprotocol: str | None = None, headers: list[tuple[str, str]] | None = None) -> None: ... + def websocket_receive(self, message: WebSocketReceiveEvent) -> None: ... + def receive(self, text_data: str | None = None, bytes_data: bytes | None = None) -> None: ... + def send( # type: ignore[override] + self, text_data: str | None = None, bytes_data: bytes | None = None, close: bool = False + ) -> None: ... + def close(self, code: int | bool | None = None, reason: str | None = None) -> None: ... + def websocket_disconnect(self, message: WebSocketDisconnectEvent) -> None: ... + def disconnect(self, code: int) -> None: ... + +class JsonWebsocketConsumer(WebsocketConsumer): + def receive(self, text_data: str | None = None, bytes_data: bytes | None = None, **kwargs: Any) -> None: ... + # content is typed as Any to match json.loads() return type - JSON can represent + # various Python types (dict, list, str, int, float, bool, None) + def receive_json(self, content: Any, **kwargs: Any) -> None: ... + # content is typed as Any to match json.dumps() input type - accepts any JSON-serializable object + def send_json(self, content: Any, close: bool = False) -> None: ... + @classmethod + def decode_json(cls, text_data: str) -> Any: ... # Returns Any like json.loads() + @classmethod + def encode_json(cls, content: Any) -> str: ... # Accepts Any like json.dumps() + +class AsyncWebsocketConsumer(AsyncConsumer): + groups: list[str] + + def __init__(self, *args: Unused, **kwargs: Unused) -> None: ... + async def websocket_connect(self, message: WebSocketConnectEvent) -> None: ... + async def connect(self) -> None: ... + async def accept(self, subprotocol: str | None = None, headers: list[tuple[str, str]] | None = None) -> None: ... + async def websocket_receive(self, message: WebSocketReceiveEvent) -> None: ... + async def receive(self, text_data: str | None = None, bytes_data: bytes | None = None) -> None: ... + async def send( # type: ignore[override] + self, text_data: str | None = None, bytes_data: bytes | None = None, close: bool = False + ) -> None: ... + async def close(self, code: int | bool | None = None, reason: str | None = None) -> None: ... + async def websocket_disconnect(self, message: WebSocketDisconnectEvent) -> None: ... + async def disconnect(self, code: int) -> None: ... + +class AsyncJsonWebsocketConsumer(AsyncWebsocketConsumer): + async def receive(self, text_data: str | None = None, bytes_data: bytes | None = None, **kwargs: Any) -> None: ... + # content is typed as Any to match json.loads() return type - JSON can represent + # various Python types (dict, list, str, int, float, bool, None) + async def receive_json(self, content: Any, **kwargs: Any) -> None: ... + # content is typed as Any to match json.dumps() input type - accepts any JSON-serializable object + async def send_json(self, content: Any, close: bool = False) -> None: ... + @classmethod + async def decode_json(cls, text_data: str) -> Any: ... # Returns Any like json.loads() + @classmethod + async def encode_json(cls, content: Any) -> str: ... # Accepts Any like json.dumps() diff --git a/stubs/channels/channels/layers.pyi b/stubs/channels/channels/layers.pyi new file mode 100644 index 000000000000..16cf0540bf62 --- /dev/null +++ b/stubs/channels/channels/layers.pyi @@ -0,0 +1,91 @@ +import asyncio +from re import Pattern +from typing import Any, ClassVar, overload +from typing_extensions import TypeAlias, deprecated + +class ChannelLayerManager: + backends: dict[str, BaseChannelLayer] + + def __init__(self) -> None: ... + @property + def configs(self) -> dict[str, Any]: ... + def make_backend(self, name: str) -> BaseChannelLayer: ... + def make_test_backend(self, name: str) -> Any: ... + def __getitem__(self, key: str) -> BaseChannelLayer: ... + def __contains__(self, key: str) -> bool: ... + def set(self, key: str, layer: BaseChannelLayer) -> BaseChannelLayer | None: ... + +_ChannelCapacityPattern: TypeAlias = Pattern[str] | str +_ChannelCapacityDict: TypeAlias = dict[_ChannelCapacityPattern, int] +_CompiledChannelCapacity: TypeAlias = list[tuple[Pattern[str], int]] + +class BaseChannelLayer: + MAX_NAME_LENGTH: ClassVar[int] = 100 + expiry: int + capacity: int + channel_capacity: _ChannelCapacityDict + channel_name_regex: Pattern[str] + group_name_regex: Pattern[str] + invalid_name_error: str + + def __init__(self, expiry: int = 60, capacity: int = 100, channel_capacity: _ChannelCapacityDict | None = None) -> None: ... + def compile_capacities(self, channel_capacity: _ChannelCapacityDict) -> _CompiledChannelCapacity: ... + def get_capacity(self, channel: str) -> int: ... + @overload + def match_type_and_length(self, name: str) -> bool: ... + @overload + def match_type_and_length(self, name: object) -> bool: ... + @overload + def require_valid_channel_name(self, name: str, receive: bool = False) -> bool: ... + @overload + def require_valid_channel_name(self, name: object, receive: bool = False) -> bool: ... + @overload + def require_valid_group_name(self, name: str) -> bool: ... + @overload + def require_valid_group_name(self, name: object) -> bool: ... + @overload + def valid_channel_names(self, names: list[str], receive: bool = False) -> bool: ... + @overload + def valid_channel_names(self, names: list[Any], receive: bool = False) -> bool: ... + def non_local_name(self, name: str) -> str: ... + async def send(self, channel: str, message: dict[str, Any]) -> None: ... + async def receive(self, channel: str) -> dict[str, Any]: ... + async def new_channel(self) -> str: ... + async def flush(self) -> None: ... + async def group_add(self, group: str, channel: str) -> None: ... + async def group_discard(self, group: str, channel: str) -> None: ... + async def group_send(self, group: str, message: dict[str, Any]) -> None: ... + @deprecated("Use require_valid_channel_name instead.") + def valid_channel_name(self, channel_name: str, receive: bool = False) -> bool: ... + @deprecated("Use require_valid_group_name instead.") + def valid_group_name(self, group_name: str) -> bool: ... + +_InMemoryQueueData: TypeAlias = tuple[float, dict[str, Any]] + +class InMemoryChannelLayer(BaseChannelLayer): + channels: dict[str, asyncio.Queue[_InMemoryQueueData]] + groups: dict[str, dict[str, float]] + group_expiry: int + + def __init__( + self, + expiry: int = 60, + group_expiry: int = 86400, + capacity: int = 100, + channel_capacity: _ChannelCapacityDict | None = ..., + ) -> None: ... + + extensions: list[str] + + async def send(self, channel: str, message: dict[str, Any]) -> None: ... + async def receive(self, channel: str) -> dict[str, Any]: ... + async def new_channel(self, prefix: str = "specific.") -> str: ... + async def flush(self) -> None: ... + async def close(self) -> None: ... + async def group_add(self, group: str, channel: str) -> None: ... + async def group_discard(self, group: str, channel: str) -> None: ... + async def group_send(self, group: str, message: dict[str, Any]) -> None: ... + +def get_channel_layer(alias: str = ...) -> BaseChannelLayer | None: ... + +channel_layers: ChannelLayerManager diff --git a/stubs/channels/channels/management/__init__.pyi b/stubs/channels/channels/management/__init__.pyi new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/stubs/channels/channels/management/commands/__init__.pyi b/stubs/channels/channels/management/commands/__init__.pyi new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/stubs/channels/channels/management/commands/runworker.pyi b/stubs/channels/channels/management/commands/runworker.pyi new file mode 100644 index 000000000000..533b94cea88d --- /dev/null +++ b/stubs/channels/channels/management/commands/runworker.pyi @@ -0,0 +1,25 @@ +import logging +from _typeshed import Unused +from argparse import ArgumentParser +from typing import TypedDict, type_check_only + +from channels.layers import BaseChannelLayer +from channels.worker import Worker +from django.core.management.base import BaseCommand + +logger: logging.Logger + +@type_check_only +class _RunWorkerCommandOption(TypedDict): + verbosity: int | None + layer: str + channels: list[str] + +class Command(BaseCommand): + leave_locale_alone: bool = True + worker_class: type[Worker] = ... + verbosity: int + channel_layer: BaseChannelLayer + + def add_arguments(self, parser: ArgumentParser) -> None: ... + def handle(self, *args: Unused, **options: _RunWorkerCommandOption) -> None: ... diff --git a/stubs/channels/channels/middleware.pyi b/stubs/channels/channels/middleware.pyi new file mode 100644 index 000000000000..339ae9218244 --- /dev/null +++ b/stubs/channels/channels/middleware.pyi @@ -0,0 +1,12 @@ +from asgiref.typing import ASGIReceiveCallable, ASGISendCallable + +from .consumer import _ChannelScope +from .utils import _ChannelApplication + +class BaseMiddleware: + inner: _ChannelApplication + + def __init__(self, inner: _ChannelApplication) -> None: ... + async def __call__( + self, scope: _ChannelScope, receive: ASGIReceiveCallable, send: ASGISendCallable + ) -> _ChannelApplication: ... diff --git a/stubs/channels/channels/routing.pyi b/stubs/channels/channels/routing.pyi new file mode 100644 index 000000000000..d2a0655d43ce --- /dev/null +++ b/stubs/channels/channels/routing.pyi @@ -0,0 +1,31 @@ +from typing import Any, type_check_only + +from asgiref.typing import ASGIReceiveCallable, ASGISendCallable, Scope +from django.urls.resolvers import URLPattern + +from .consumer import _ASGIApplicationProtocol, _ChannelScope +from .utils import _ChannelApplication + +def get_default_application() -> ProtocolTypeRouter: ... + +class ProtocolTypeRouter: + application_mapping: dict[str, _ChannelApplication] + + def __init__(self, application_mapping: dict[str, Any]) -> None: ... + async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ... + +@type_check_only +class _ExtendedURLPattern(URLPattern): + callback: _ASGIApplicationProtocol | URLRouter + +class URLRouter: + routes: list[_ExtendedURLPattern | URLRouter] + + def __init__(self, routes: list[_ExtendedURLPattern | URLRouter]) -> None: ... + async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ... + +class ChannelNameRouter: + application_mapping: dict[str, _ChannelApplication] + + def __init__(self, application_mapping: dict[str, _ChannelApplication]) -> None: ... + async def __call__(self, scope: _ChannelScope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ... diff --git a/stubs/channels/channels/security/__init__.pyi b/stubs/channels/channels/security/__init__.pyi new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/stubs/channels/channels/security/websocket.pyi b/stubs/channels/channels/security/websocket.pyi new file mode 100644 index 000000000000..aa76a1b4012e --- /dev/null +++ b/stubs/channels/channels/security/websocket.pyi @@ -0,0 +1,25 @@ +from collections.abc import Iterable +from re import Pattern +from typing import Any +from urllib.parse import ParseResult + +from asgiref.typing import ASGIReceiveCallable, ASGISendCallable +from channels.consumer import _ChannelScope +from channels.generic.websocket import AsyncWebsocketConsumer +from channels.utils import _ChannelApplication + +class OriginValidator: + application: _ChannelApplication + allowed_origins: Iterable[str | Pattern[str]] + + def __init__(self, application: _ChannelApplication, allowed_origins: Iterable[str | Pattern[str]]) -> None: ... + async def __call__(self, scope: _ChannelScope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> Any: ... + def valid_origin(self, parsed_origin: ParseResult | None) -> bool: ... + def validate_origin(self, parsed_origin: ParseResult | None) -> bool: ... + def match_allowed_origin(self, parsed_origin: ParseResult | None, pattern: str | Pattern[str]) -> bool: ... + def get_origin_port(self, origin: ParseResult | None) -> int | None: ... + +def AllowedHostsOriginValidator(application: _ChannelApplication) -> OriginValidator: ... + +class WebsocketDenier(AsyncWebsocketConsumer): + async def connect(self) -> None: ... diff --git a/stubs/channels/channels/sessions.pyi b/stubs/channels/channels/sessions.pyi new file mode 100644 index 000000000000..04a4eb729535 --- /dev/null +++ b/stubs/channels/channels/sessions.pyi @@ -0,0 +1,56 @@ +import datetime +from collections.abc import Awaitable +from typing import Any + +from asgiref.typing import ASGIReceiveCallable, ASGISendCallable +from channels.consumer import _ChannelScope +from channels.utils import _ChannelApplication +from django.contrib.sessions.backends.base import SessionBase + +class CookieMiddleware: + inner: _ChannelApplication + + def __init__(self, inner: _ChannelApplication) -> None: ... + + # Returns the same type as the provided _ChannelApplication. + async def __call__(self, scope: _ChannelScope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> Any: ... + @classmethod + def set_cookie( + cls, + message: dict[str, Any], + key: str, + value: str = "", + max_age: int | None = None, + expires: str | datetime.datetime | None = None, + path: str = "/", + domain: str | None = None, + secure: bool = False, + httponly: bool = False, + samesite: str = "lax", + ) -> None: ... + @classmethod + def delete_cookie(cls, message: dict[str, Any], key: str, path: str = "/", domain: str | None = None) -> None: ... + +class InstanceSessionWrapper: + save_message_types: list[str] + cookie_response_message_types: list[str] + cookie_name: str + session_store: SessionBase + scope: _ChannelScope + activated: bool + real_send: ASGISendCallable + + def __init__(self, scope: _ChannelScope, send: ASGISendCallable) -> None: ... + async def resolve_session(self) -> None: ... + async def send(self, message: dict[str, Any]) -> Awaitable[None]: ... + async def save_session(self) -> None: ... + +class SessionMiddleware: + inner: _ChannelApplication + + def __init__(self, inner: _ChannelApplication) -> None: ... + async def __call__( + self, scope: _ChannelScope, receive: ASGIReceiveCallable, send: ASGISendCallable + ) -> _ChannelApplication: ... + +def SessionMiddlewareStack(inner: _ChannelApplication) -> _ChannelApplication: ... diff --git a/stubs/channels/channels/testing/__init__.pyi b/stubs/channels/channels/testing/__init__.pyi new file mode 100644 index 000000000000..1cb75a0bf6dd --- /dev/null +++ b/stubs/channels/channels/testing/__init__.pyi @@ -0,0 +1,6 @@ +from .application import ApplicationCommunicator +from .http import HttpCommunicator +from .live import ChannelsLiveServerTestCase +from .websocket import WebsocketCommunicator + +__all__ = ["ApplicationCommunicator", "HttpCommunicator", "ChannelsLiveServerTestCase", "WebsocketCommunicator"] diff --git a/stubs/channels/channels/testing/application.pyi b/stubs/channels/channels/testing/application.pyi new file mode 100644 index 000000000000..db3f567c1d61 --- /dev/null +++ b/stubs/channels/channels/testing/application.pyi @@ -0,0 +1,21 @@ +from typing import Any + +from asgiref.testing import ApplicationCommunicator as BaseApplicationCommunicator + +def no_op() -> None: ... + +class ApplicationCommunicator(BaseApplicationCommunicator): + # ASGI messages are dictionaries with a "type" key and protocol-specific fields. + # Dictionary values can be strings, bytes, lists, or other types depending on the protocol: + # - HTTP: {"type": "http.request", "body": b"request data", "headers": [...], ...} + # - WebSocket: {"type": "websocket.receive", "bytes": b"binary data"} or {"text": "string"} + # - Custom protocols: Application-specific message dictionaries + async def send_input(self, message: dict[str, Any]) -> None: ... + async def receive_output(self, timeout: float = 1) -> dict[str, Any]: ... + + # The following methods are not present in the original source code, + # but are commonly used in practice. Since the base package doesn't + # provide type hints for them, they are added here to improve type correctness. + async def receive_nothing(self, timeout: float = 0.1, interval: float = 0.01) -> bool: ... + async def wait(self, timeout: float = 1) -> None: ... + def stop(self, exceptions: bool = True) -> None: ... diff --git a/stubs/channels/channels/testing/http.pyi b/stubs/channels/channels/testing/http.pyi new file mode 100644 index 000000000000..6eb6650036c8 --- /dev/null +++ b/stubs/channels/channels/testing/http.pyi @@ -0,0 +1,41 @@ +from collections.abc import Iterable +from typing import Literal, TypedDict, type_check_only + +from channels.testing.application import ApplicationCommunicator +from channels.utils import _ChannelApplication + +# HTTP test-specific response type +@type_check_only +class _HTTPTestResponse(TypedDict, total=False): + status: int + headers: Iterable[tuple[bytes, bytes]] + body: bytes + +@type_check_only +class _HTTPTestScope(TypedDict, total=False): + type: Literal["http"] + http_version: str + method: str + scheme: str + path: str + raw_path: bytes + query_string: bytes + root_path: str + headers: Iterable[tuple[bytes, bytes]] | None + client: tuple[str, int] | None + server: tuple[str, int | None] | None + +class HttpCommunicator(ApplicationCommunicator): + scope: _HTTPTestScope + body: bytes + sent_request: bool + + def __init__( + self, + application: _ChannelApplication, + method: str, + path: str, + body: bytes = b"", + headers: Iterable[tuple[bytes, bytes]] | None = None, + ) -> None: ... + async def get_response(self, timeout: float = 1) -> _HTTPTestResponse: ... diff --git a/stubs/channels/channels/testing/live.pyi b/stubs/channels/channels/testing/live.pyi new file mode 100644 index 000000000000..769604458d54 --- /dev/null +++ b/stubs/channels/channels/testing/live.pyi @@ -0,0 +1,25 @@ +from collections.abc import Callable +from typing import Any, ClassVar +from typing_extensions import TypeAlias + +from channels.routing import ProtocolTypeRouter +from channels.utils import _ChannelApplication +from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler +from django.test.testcases import TransactionTestCase + +DaphneProcess: TypeAlias = Any # TODO: temporary hack for daphne.testing.DaphneProcess; remove once daphne provides types + +_StaticWrapper: TypeAlias = Callable[[ProtocolTypeRouter], _ChannelApplication] + +def make_application(*, static_wrapper: _StaticWrapper | None) -> Any: ... + +class ChannelsLiveServerTestCase(TransactionTestCase): + host: ClassVar[str] = "localhost" + ProtocolServerProcess: ClassVar[type[DaphneProcess]] = ... + static_wrapper: ClassVar[type[ASGIStaticFilesHandler]] = ... + serve_static: ClassVar[bool] = True + + @property + def live_server_url(self) -> str: ... + @property + def live_server_ws_url(self) -> str: ... diff --git a/stubs/channels/channels/testing/websocket.pyi b/stubs/channels/channels/testing/websocket.pyi new file mode 100644 index 000000000000..f303e90e5c39 --- /dev/null +++ b/stubs/channels/channels/testing/websocket.pyi @@ -0,0 +1,54 @@ +from collections.abc import Iterable +from typing import Any, Literal, TypedDict, overload, type_check_only +from typing_extensions import NotRequired, TypeAlias + +from asgiref.typing import ASGIVersions +from channels.testing.application import ApplicationCommunicator +from channels.utils import _ChannelApplication + +@type_check_only +class _WebsocketTestScope(TypedDict, total=False): + spec_version: int + type: Literal["websocket"] + asgi: ASGIVersions + http_version: str + scheme: str + path: str + raw_path: bytes + query_string: bytes + root_path: str + headers: Iterable[tuple[bytes, bytes]] | None + client: tuple[str, int] | None + server: tuple[str, int | None] | None + subprotocols: Iterable[str] | None + state: NotRequired[dict[str, Any]] + extensions: dict[str, dict[object, object]] | None + +_Connected: TypeAlias = bool +_CloseCodeOrAcceptSubProtocol: TypeAlias = int | str | None +_WebsocketConnectResponse: TypeAlias = tuple[_Connected, _CloseCodeOrAcceptSubProtocol] + +class WebsocketCommunicator(ApplicationCommunicator): + scope: _WebsocketTestScope + response_headers: list[tuple[bytes, bytes]] | None + + def __init__( + self, + application: _ChannelApplication, + path: str, + headers: Iterable[tuple[bytes, bytes]] | None = None, + subprotocols: Iterable[str] | None = None, + spec_version: int | None = None, + ) -> None: ... + async def connect(self, timeout: float = 1) -> _WebsocketConnectResponse: ... + async def send_to(self, text_data: str | None = None, bytes_data: bytes | None = None) -> None: ... + async def receive_from(self, timeout: float = 1) -> str | bytes: ... + + # These overloads reflect common usage, where users typically send and receive `dict[str, Any]`. + # The base case allows `Any` to support broader `json.dumps` / `json.loads` compatibility. + @overload + async def send_json_to(self, data: dict[str, Any]) -> None: ... + @overload + async def send_json_to(self, data: Any) -> None: ... + async def receive_json_from(self, timeout: float = 1) -> Any: ... + async def disconnect(self, code: int = 1000, timeout: float = 1) -> None: ... diff --git a/stubs/channels/channels/utils.pyi b/stubs/channels/channels/utils.pyi new file mode 100644 index 000000000000..0e0818abbbb6 --- /dev/null +++ b/stubs/channels/channels/utils.pyi @@ -0,0 +1,20 @@ +from collections.abc import Awaitable, Callable +from typing import Any, Protocol, type_check_only +from typing_extensions import TypeAlias + +from asgiref.typing import ASGIApplication, ASGIReceiveCallable + +def name_that_thing(thing: object) -> str: ... +async def await_many_dispatch( + consumer_callables: list[Callable[[], Awaitable[ASGIReceiveCallable]]], dispatch: Callable[[dict[str, Any]], Awaitable[None]] +) -> None: ... + +# Defines a generic ASGI middleware protocol. +# All arguments are typed as `Any` to maximize compatibility with third-party ASGI middleware +# that may not strictly follow type conventions or use more specific signatures. +@type_check_only +class _MiddlewareProtocol(Protocol): + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + async def __call__(self, scope: Any, receive: Any, send: Any) -> Any: ... + +_ChannelApplication: TypeAlias = _MiddlewareProtocol | ASGIApplication # noqa: Y047 diff --git a/stubs/channels/channels/worker.pyi b/stubs/channels/channels/worker.pyi new file mode 100644 index 000000000000..a20b5feec275 --- /dev/null +++ b/stubs/channels/channels/worker.pyi @@ -0,0 +1,13 @@ +from asgiref.server import StatelessServer +from channels.layers import BaseChannelLayer +from channels.utils import _ChannelApplication + +class Worker(StatelessServer): + channels: list[str] + channel_layer: BaseChannelLayer + + def __init__( + self, application: _ChannelApplication, channels: list[str], channel_layer: BaseChannelLayer, max_applications: int = 1000 + ) -> None: ... + async def handle(self) -> None: ... + async def listener(self, channel: str) -> None: ...