Skip to content

Commit

Permalink
Refactor WSMessage to use tagged unions (#7319)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamsorcerer authored Oct 26, 2024
1 parent 7982b59 commit 5f4c052
Show file tree
Hide file tree
Showing 19 changed files with 195 additions and 99 deletions.
1 change: 1 addition & 0 deletions CHANGES/7319.breaking.rst
12 changes: 12 additions & 0 deletions CHANGES/7319.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Changed ``WSMessage`` to a tagged union of ``NamedTuple`` -- by :user:`Dreamsorcerer`.

This change allows type checkers to know the precise type of ``data``
after checking the ``type`` attribute.

If accessing messages by tuple indexes, the order has now changed.
Code such as:
``typ, data, extra = ws_message``
will need to be changed to:
``data, extra, typ = ws_message``

No changes are needed if accessing by attribute name.
4 changes: 2 additions & 2 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1457,9 +1457,9 @@ Misc

- Rename aiohttp.websocket to aiohttp._ws_impl

- Rename aiohttp.MsgType tp aiohttp.WSMsgType
- Rename ``aiohttp.MsgType`` to ``aiohttp.WSMsgType``

- Introduce aiohttp.WSMessage officially
- Introduce ``aiohttp.WSMessage`` officially

- Rename Message -> WSMessage

Expand Down
16 changes: 8 additions & 8 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import dataclasses
import sys
from types import TracebackType
from typing import Any, Final, Optional, Type, cast
from typing import Any, Final, Optional, Type

from .client_exceptions import ClientError, ServerTimeoutError
from .client_reqrep import ClientResponse
Expand All @@ -17,7 +17,7 @@
WSMessage,
WSMsgType,
)
from .http_websocket import WebSocketWriter # WSMessage
from .http_websocket import WebSocketWriter, WSMessageError
from .streams import EofStream, FlowControlDataQueue
from .typedefs import (
DEFAULT_JSON_DECODER,
Expand Down Expand Up @@ -173,7 +173,7 @@ def _handle_ping_pong_exception(self, exc: BaseException) -> None:
self._exception = exc
self._response.close()
if self._waiting and not self._closing:
self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None))
self._reader.feed_data(WSMessageError(data=exc, extra=None))

def _set_closed(self) -> None:
"""Set the connection to closed.
Expand Down Expand Up @@ -342,7 +342,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
except EofStream:
self._close_code = WSCloseCode.OK
await self.close()
return WSMessage(WSMsgType.CLOSED, None, None)
return WS_CLOSED_MESSAGE
except ClientError:
# Likely ServerDisconnectedError when connection is lost
self._set_closed()
Expand All @@ -351,13 +351,13 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
except WebSocketError as exc:
self._close_code = exc.code
await self.close(code=exc.code)
return WSMessage(WSMsgType.ERROR, exc, None)
return WSMessageError(data=exc)
except Exception as exc:
self._exception = exc
self._set_closing()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
return WSMessageError(data=exc)

if msg.type is WSMsgType.CLOSE:
self._set_closing()
Expand All @@ -379,13 +379,13 @@ async def receive_str(self, *, timeout: Optional[float] = None) -> str:
msg = await self.receive(timeout)
if msg.type is not WSMsgType.TEXT:
raise TypeError(f"Received message {msg.type}:{msg.data!r} is not str")
return cast(str, msg.data)
return msg.data

async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
msg = await self.receive(timeout)
if msg.type is not WSMsgType.BINARY:
raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes")
return cast(bytes, msg.data)
return msg.data

async def receive_json(
self,
Expand Down
114 changes: 89 additions & 25 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Callable,
Final,
List,
Literal,
NamedTuple,
Optional,
Pattern,
Expand Down Expand Up @@ -110,26 +111,86 @@ class WSMsgType(IntEnum):
MASK_LEN: Final[int] = 4


class WSMessage(NamedTuple):
type: WSMsgType
# To type correctly, this would need some kind of tagged union for each type.
data: Any
extra: Optional[str]
class WSMessageContinuation(NamedTuple):
data: bytes
extra: Optional[str] = None
type: Literal[WSMsgType.CONTINUATION] = WSMsgType.CONTINUATION

def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
"""Return parsed JSON data.

.. versionadded:: 0.22
"""
class WSMessageText(NamedTuple):
data: str
extra: Optional[str] = None
type: Literal[WSMsgType.TEXT] = WSMsgType.TEXT

def json(
self, *, loads: Callable[[Union[str, bytes, bytearray]], Any] = json.loads
) -> Any:
"""Return parsed JSON data."""
return loads(self.data)


class WSMessageBinary(NamedTuple):
data: bytes
extra: Optional[str] = None
type: Literal[WSMsgType.BINARY] = WSMsgType.BINARY

def json(
self, *, loads: Callable[[Union[str, bytes, bytearray]], Any] = json.loads
) -> Any:
"""Return parsed JSON data."""
return loads(self.data)


# Constructing the tuple directly to avoid the overhead of
# the lambda and arg processing since NamedTuples are constructed
# with a run time built lambda
# https://github.com/python/cpython/blob/d83fcf8371f2f33c7797bc8f5423a8bca8c46e5c/Lib/collections/__init__.py#L441
WS_CLOSED_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSED, None, None))
WS_CLOSING_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSING, None, None))
class WSMessagePing(NamedTuple):
data: bytes
extra: Optional[str] = None
type: Literal[WSMsgType.PING] = WSMsgType.PING


class WSMessagePong(NamedTuple):
data: bytes
extra: Optional[str] = None
type: Literal[WSMsgType.PONG] = WSMsgType.PONG


class WSMessageClose(NamedTuple):
data: int
extra: Optional[str] = None
type: Literal[WSMsgType.CLOSE] = WSMsgType.CLOSE


class WSMessageClosing(NamedTuple):
data: None = None
extra: Optional[str] = None
type: Literal[WSMsgType.CLOSING] = WSMsgType.CLOSING


class WSMessageClosed(NamedTuple):
data: None = None
extra: Optional[str] = None
type: Literal[WSMsgType.CLOSED] = WSMsgType.CLOSED


class WSMessageError(NamedTuple):
data: BaseException
extra: Optional[str] = None
type: Literal[WSMsgType.ERROR] = WSMsgType.ERROR


WSMessage = Union[
WSMessageContinuation,
WSMessageText,
WSMessageBinary,
WSMessagePing,
WSMessagePong,
WSMessageClose,
WSMessageClosing,
WSMessageClosed,
WSMessageError,
]

WS_CLOSED_MESSAGE = WSMessageClosed()
WS_CLOSING_MESSAGE = WSMessageClosing()


class WebSocketError(Exception):
Expand Down Expand Up @@ -327,6 +388,7 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
return False, b""

def _feed_data(self, data: bytes) -> None:
msg: WSMessage
for fin, opcode, payload, compressed in self.parse_frame(data):
if opcode in MESSAGE_TYPES_WITH_CONTENT:
# load text/binary
Expand Down Expand Up @@ -406,13 +468,17 @@ def _feed_data(self, data: bytes) -> None:
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc

# tuple.__new__ is used to avoid the overhead of the lambda
msg = tuple.__new__(WSMessage, (WSMsgType.TEXT, text, ""))
# XXX: The Text and Binary messages here can be a performance
# bottleneck, so we use tuple.__new__ to improve performance.
# This is not type safe, but many tests should fail in
# test_client_ws_functional.py if this is wrong.
msg = tuple.__new__(WSMessageText, (text, "", WSMsgType.TEXT))
self.queue.feed_data(msg)
continue

# tuple.__new__ is used to avoid the overhead of the lambda
msg = tuple.__new__(WSMessage, (WSMsgType.BINARY, payload_merged, ""))
msg = tuple.__new__(
WSMessageBinary, (payload_merged, "", WSMsgType.BINARY)
)
self.queue.feed_data(msg)
elif opcode == WSMsgType.CLOSE:
if len(payload) >= 2:
Expand All @@ -428,25 +494,23 @@ def _feed_data(self, data: bytes) -> None:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
msg = tuple.__new__(
WSMessage, (WSMsgType.CLOSE, close_code, close_message)
)
msg = WSMessageClose(data=close_code, extra=close_message)
elif payload:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Invalid close frame: {fin} {opcode} {payload!r}",
)
else:
msg = tuple.__new__(WSMessage, (WSMsgType.CLOSE, 0, ""))
msg = WSMessageClose(data=0, extra="")

self.queue.feed_data(msg)

elif opcode == WSMsgType.PING:
msg = tuple.__new__(WSMessage, (WSMsgType.PING, payload, ""))
msg = WSMessagePing(data=payload, extra="")
self.queue.feed_data(msg)

elif opcode == WSMsgType.PONG:
msg = tuple.__new__(WSMessage, (WSMsgType.PONG, payload, ""))
msg = WSMessagePong(data=payload, extra="")
self.queue.feed_data(msg)

else:
Expand Down
21 changes: 11 additions & 10 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import hashlib
import json
import sys
from typing import Any, Final, Iterable, Optional, Tuple, cast
from typing import Any, Final, Iterable, Optional, Tuple

from multidict import CIMultiDict

Expand All @@ -21,10 +21,11 @@
WebSocketWriter,
WSCloseCode,
WSMessage,
WSMsgType as WSMsgType,
WSMsgType,
ws_ext_gen,
ws_ext_parse,
)
from .http_websocket import WSMessageError
from .log import ws_logger
from .streams import EofStream, FlowControlDataQueue
from .typedefs import JSONDecoder, JSONEncoder
Expand Down Expand Up @@ -210,7 +211,7 @@ def _handle_ping_pong_exception(self, exc: BaseException) -> None:
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
self._exception = exc
if self._waiting and not self._closing and self._reader is not None:
self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None))
self._reader.feed_data(WSMessageError(data=exc, extra=None))

def _set_closed(self) -> None:
"""Set the connection to closed.
Expand Down Expand Up @@ -505,13 +506,13 @@ async def close(
self._exception = asyncio.TimeoutError()
return True

def _set_closing(self, code: WSCloseCode) -> None:
def _set_closing(self, code: int) -> None:
"""Set the close code and mark the connection as closing."""
self._closing = True
self._close_code = code
self._cancel_heartbeat()

def _set_code_close_transport(self, code: WSCloseCode) -> None:
def _set_code_close_transport(self, code: int) -> None:
"""Set the close code and close the transport."""
self._close_code = code
self._close_transport()
Expand Down Expand Up @@ -562,16 +563,16 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
except EofStream:
self._close_code = WSCloseCode.OK
await self.close()
return WSMessage(WSMsgType.CLOSED, None, None)
return WS_CLOSED_MESSAGE
except WebSocketError as exc:
self._close_code = exc.code
await self.close(code=exc.code)
return WSMessage(WSMsgType.ERROR, exc, None)
return WSMessageError(data=exc)
except Exception as exc:
self._exception = exc
self._set_closing(WSCloseCode.ABNORMAL_CLOSURE)
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
return WSMessageError(data=exc)

if msg.type is WSMsgType.CLOSE:
self._set_closing(msg.data)
Expand Down Expand Up @@ -600,13 +601,13 @@ async def receive_str(self, *, timeout: Optional[float] = None) -> str:
msg.type, msg.data
)
)
return cast(str, msg.data)
return msg.data

async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
msg = await self.receive(timeout)
if msg.type is not WSMsgType.BINARY:
raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes")
return cast(bytes, msg.data)
return msg.data

async def receive_json(
self, *, loads: JSONDecoder = json.loads, timeout: Optional[float] = None
Expand Down
2 changes: 0 additions & 2 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,6 @@ wildcard
Workflow
ws
wsgi
WSMessage
WSMsgType
wss
www
xxx
Expand Down
Loading

0 comments on commit 5f4c052

Please sign in to comment.