Skip to content

Support Trio via AnyIO #3568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ['async-timeout>=4.0.3; python_full_version<"3.11.3"']
dependencies = ['anyio>=4,<4.6']

[project.optional-dependencies]
hiredis = [
Expand Down
24 changes: 11 additions & 13 deletions redis/_parsers/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import sys
from abc import ABC
from asyncio import IncompleteReadError, StreamReader, TimeoutError
from typing import List, Optional, Union

if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout
import anyio
from anyio import ClosedResourceError, EndOfStream, IncompleteRead
from anyio.streams.buffered import BufferedByteReceiveStream

from ..exceptions import (
AskError,
Expand Down Expand Up @@ -146,7 +143,7 @@ class AsyncBaseParser(BaseParser):
__slots__ = "_stream", "_read_size"

def __init__(self, socket_read_size: int):
self._stream: Optional[StreamReader] = None
self._stream: Optional[BufferedByteReceiveStream] = None
self._read_size = socket_read_size

async def can_read_destructive(self) -> bool:
Expand Down Expand Up @@ -190,12 +187,13 @@ def on_disconnect(self):
async def can_read_destructive(self) -> bool:
if not self._connected:
raise RedisError("Buffer is closed.")
if self._buffer:
if self._buffer or self._stream.buffer:
return True
try:
async with async_timeout(0):
return self._stream.at_eof()
except TimeoutError:
with anyio.fail_after(0):
await self._stream.receive(0)
return True
except (EndOfStream, TimeoutError, ClosedResourceError):
Comment on lines +193 to +196
Copy link
Author

@thearchitector thearchitector Mar 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unsure whether this pattern does what is intended, pending the test suite.

return False

async def _read(self, length: int) -> bytes:
Expand All @@ -210,8 +208,8 @@ async def _read(self, length: int) -> bytes:
else:
tail = self._buffer[self._pos :]
try:
data = await self._stream.readexactly(want - len(tail))
except IncompleteReadError as error:
data = await self._stream.receive_exactly(want - len(tail))
except IncompleteRead as error:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
result = (tail + data)[:-2]
self._chunks.append(data)
Expand Down
19 changes: 8 additions & 11 deletions redis/_parsers/hiredis.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import asyncio
import socket
import sys
from typing import Callable, List, Optional, TypedDict, Union

if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout
import anyio
from anyio import EndOfStream

from ..exceptions import ConnectionError, InvalidResponse, RedisError
from ..typing import EncodableT
Expand Down Expand Up @@ -180,15 +176,16 @@ async def can_read_destructive(self):
if self._reader.gets() is not NOT_ENOUGH_DATA:
return True
try:
async with async_timeout(0):
with anyio.fail_after(0):
return await self.read_from_socket()
except asyncio.TimeoutError:
except TimeoutError:
return False

async def read_from_socket(self):
buffer = await self._stream.read(self._read_size)
if not buffer or not isinstance(buffer, bytes):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
try:
buffer = await self._stream.receive(self._read_size)
except EndOfStream as error:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
self._reader.feed(buffer)
# data was read from the socket and added to the buffer.
# return True to indicate that data was read.
Expand Down
33 changes: 18 additions & 15 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
cast,
)

import anyio
import sniffio

from redis._parsers.helpers import (
_RedisCallbacks,
_RedisCallbacksRESP2,
Expand Down Expand Up @@ -365,7 +368,7 @@ def __init__(
# If using a single connection client, we need to lock creation-of and use-of
# the client in order to avoid race conditions such as using asyncio.gather
# on a set of redis commands
self._single_conn_lock = asyncio.Lock()
self._single_conn_lock = anyio.Lock()

def __repr__(self):
return (
Expand Down Expand Up @@ -471,7 +474,7 @@ async def transaction(
return func_value if value_from_callable else exec_value
except WatchError:
if watch_delay is not None and watch_delay > 0:
await asyncio.sleep(watch_delay)
await anyio.sleep(watch_delay)
continue

def lock(
Expand Down Expand Up @@ -587,12 +590,16 @@ def __del__(
self,
_warn: Any = warnings.warn,
_grl: Any = asyncio.get_running_loop,
_sniff: Any = sniffio.current_async_library,
) -> None:
if hasattr(self, "connection") and (self.connection is not None):
_warn(f"Unclosed client session {self!r}", ResourceWarning, source=self)
try:
context = {"client": self, "message": self._DEL_MESSAGE}
_grl().call_exception_handler(context)
# trio does not have a concept of a global exception handler since tasks
# are intended to be managed within specific structured contexts
if _sniff() == "asyncio":
context = {"client": self, "message": self._DEL_MESSAGE}
_grl().call_exception_handler(context)
except RuntimeError:
pass
self.connection._close()
Expand Down Expand Up @@ -833,7 +840,7 @@ def __init__(
self.pending_unsubscribe_channels = set()
self.patterns = {}
self.pending_unsubscribe_patterns = set()
self._lock = asyncio.Lock()
self._lock = anyio.Lock()

async def __aenter__(self):
return self
Expand Down Expand Up @@ -991,10 +998,7 @@ async def check_health(self):
"did you forget to call subscribe() or psubscribe()?"
)

if (
conn.health_check_interval
and asyncio.get_running_loop().time() > conn.next_health_check
):
if conn.health_check_interval and anyio.current_time() > conn.next_health_check:
await conn.send_command(
"PING", self.HEALTH_CHECK_MESSAGE, check_health=False
)
Expand Down Expand Up @@ -1185,12 +1189,11 @@ async def run(

This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
redis-py, but it is a coroutine. To launch it as a separate task, use
``asyncio.create_task``:
your async backend's task creation and cancellation pattern (e.g. ``asyncio.create_task`` and ``task.cancel``).

>>> task = asyncio.create_task(pubsub.run())

To shut it down, use asyncio cancellation:
Example:

>>> task = asyncio.create_task(pubsub.run())
>>> task.cancel()
>>> await task
"""
Expand All @@ -1207,7 +1210,7 @@ async def run(
await self.get_message(
ignore_subscribe_messages=True, timeout=poll_timeout
)
except asyncio.CancelledError:
except anyio.get_cancelled_exc_class():
raise
except BaseException as e:
if exception_handler is None:
Expand All @@ -1217,7 +1220,7 @@ async def run(
await res
# Ensure that other tasks on the event loop get a chance to run
# if we didn't have to block for I/O anywhere.
await asyncio.sleep(0)
await anyio.lowlevel.checkpoint()


class PubsubWorkerExceptionHandler(Protocol):
Expand Down
Loading