Skip to content

Commit b2760e6

Browse files
committed
Support AnyIO
1 parent 30e3189 commit b2760e6

File tree

6 files changed

+384
-385
lines changed

6 files changed

+384
-385
lines changed

pyproject.toml

+5-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ classifiers = [
4040
"Programming Language :: Python :: 3.12",
4141
"Programming Language :: Python :: 3.13",
4242
]
43-
dependencies = ["cffi; implementation_name == 'pypy'"]
43+
dependencies = [
44+
"cffi; implementation_name == 'pypy'",
45+
"anyioutils >=0.4.2"
46+
]
4447
description = "Python bindings for 0MQ"
4548
readme = "README.md"
4649

@@ -144,7 +147,7 @@ search = '__version__: str = "{current_version}"'
144147
[tool.cibuildwheel]
145148
build-verbosity = "1"
146149
free-threaded-support = true
147-
test-requires = ["pytest>=6", "importlib_metadata"]
150+
test-requires = ["pytest>=6", "importlib_metadata", "exceptiongroup;python_version<'3.11'"]
148151
test-command = "pytest -vsx {package}/tools/test_wheel.py"
149152

150153
[tool.cibuildwheel.linux]

tests/test_asyncio.py

+108-112
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,19 @@
1111
from multiprocessing import Process
1212

1313
import pytest
14+
from anyio import create_task_group, move_on_after, sleep
15+
from anyioutils import CancelledError, create_task
1416
from pytest import mark
1517

1618
import zmq
1719
import zmq.asyncio as zaio
1820

21+
if sys.version_info < (3, 11):
22+
from exceptiongroup import BaseExceptionGroup, ExceptionGroup
23+
24+
25+
pytestmark = pytest.mark.anyio
26+
1927

2028
@pytest.fixture
2129
def Context(event_loop):
@@ -46,23 +54,17 @@ def test_instance_subclass_second(context):
4654
async def test_recv_multipart(context, create_bound_pair):
4755
a, b = create_bound_pair(zmq.PUSH, zmq.PULL)
4856
f = b.recv_multipart()
49-
assert not f.done()
5057
await a.send(b"hi")
51-
recvd = await f
52-
assert recvd == [b"hi"]
58+
assert await f == [b"hi"]
5359

5460

5561
async def test_recv(create_bound_pair):
5662
a, b = create_bound_pair(zmq.PUSH, zmq.PULL)
5763
f1 = b.recv()
5864
f2 = b.recv()
59-
assert not f1.done()
60-
assert not f2.done()
6165
await a.send_multipart([b"hi", b"there"])
62-
recvd = await f2
63-
assert f1.done()
64-
assert f1.result() == b"hi"
65-
assert recvd == b"there"
66+
assert await f1 == b"hi"
67+
assert await f2 == b"there"
6668

6769

6870
@mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO")
@@ -72,82 +74,70 @@ async def test_recv_timeout(push_pull):
7274
f1 = b.recv()
7375
b.rcvtimeo = 1000
7476
f2 = b.recv_multipart()
75-
with pytest.raises(zmq.Again):
77+
with pytest.raises(ExceptionGroup) as excinfo:
7678
await f1
79+
assert excinfo.group_contains(zmq.Again)
7780
await a.send_multipart([b"hi", b"there"])
7881
recvd = await f2
79-
assert f2.done()
8082
assert recvd == [b"hi", b"there"]
8183

8284

8385
@mark.skipif(not hasattr(zmq, "SNDTIMEO"), reason="requires SNDTIMEO")
8486
async def test_send_timeout(socket):
8587
s = socket(zmq.PUSH)
8688
s.sndtimeo = 100
87-
with pytest.raises(zmq.Again):
89+
with pytest.raises(ExceptionGroup) as excinfo:
8890
await s.send(b"not going anywhere")
91+
assert excinfo.group_contains(zmq.Again)
8992

9093

9194
async def test_recv_string(push_pull):
9295
a, b = push_pull
9396
f = b.recv_string()
94-
assert not f.done()
9597
msg = "πøøπ"
9698
await a.send_string(msg)
9799
recvd = await f
98-
assert f.done()
99-
assert f.result() == msg
100100
assert recvd == msg
101101

102102

103103
async def test_recv_json(push_pull):
104104
a, b = push_pull
105105
f = b.recv_json()
106-
assert not f.done()
107106
obj = dict(a=5)
108107
await a.send_json(obj)
109108
recvd = await f
110-
assert f.done()
111-
assert f.result() == obj
112109
assert recvd == obj
113110

114111

115112
async def test_recv_json_cancelled(push_pull):
116-
a, b = push_pull
117-
f = b.recv_json()
118-
assert not f.done()
119-
f.cancel()
120-
# cycle eventloop to allow cancel events to fire
121-
await asyncio.sleep(0)
122-
obj = dict(a=5)
123-
await a.send_json(obj)
124-
# CancelledError change in 3.8 https://bugs.python.org/issue32528
125-
if sys.version_info < (3, 8):
126-
with pytest.raises(CancelledError):
113+
async with create_task_group() as tg:
114+
a, b = push_pull
115+
f = create_task(b.recv_json(), tg)
116+
f.cancel(raise_exception=False)
117+
# cycle eventloop to allow cancel events to fire
118+
await sleep(0)
119+
obj = dict(a=5)
120+
await a.send_json(obj)
121+
recvd = await f.wait()
122+
assert f.cancelled()
123+
assert f.done()
124+
# give it a chance to incorrectly consume the event
125+
events = await b.poll(timeout=5)
126+
assert events
127+
await sleep(0)
128+
# make sure cancelled recv didn't eat up event
129+
f = b.recv_json()
130+
with move_on_after(5):
127131
recvd = await f
128-
else:
129-
with pytest.raises(asyncio.exceptions.CancelledError):
130-
recvd = await f
131-
assert f.done()
132-
# give it a chance to incorrectly consume the event
133-
events = await b.poll(timeout=5)
134-
assert events
135-
await asyncio.sleep(0)
136-
# make sure cancelled recv didn't eat up event
137-
f = b.recv_json()
138-
recvd = await asyncio.wait_for(f, timeout=5)
139-
assert recvd == obj
132+
assert recvd == obj
140133

141134

142135
async def test_recv_pyobj(push_pull):
143136
a, b = push_pull
144137
f = b.recv_pyobj()
145-
assert not f.done()
146138
obj = dict(a=5)
147139
await a.send_pyobj(obj)
148140
recvd = await f
149-
assert f.done()
150-
assert f.result() == obj
151141
assert recvd == obj
152142

153143

@@ -206,85 +196,90 @@ async def test_custom_serialize_error(dealer_router):
206196
async def test_recv_dontwait(push_pull):
207197
push, pull = push_pull
208198
f = pull.recv(zmq.DONTWAIT)
209-
with pytest.raises(zmq.Again):
199+
with pytest.raises(BaseExceptionGroup) as excinfo:
210200
await f
201+
assert excinfo.group_contains(zmq.Again)
211202
await push.send(b"ping")
212203
await pull.poll() # ensure message will be waiting
213-
f = pull.recv(zmq.DONTWAIT)
214-
assert f.done()
215-
msg = await f
204+
msg = await pull.recv(zmq.DONTWAIT)
216205
assert msg == b"ping"
217206

218207

219208
async def test_recv_cancel(push_pull):
220-
a, b = push_pull
221-
f1 = b.recv()
222-
f2 = b.recv_multipart()
223-
assert f1.cancel()
224-
assert f1.done()
225-
assert not f2.done()
226-
await a.send_multipart([b"hi", b"there"])
227-
recvd = await f2
228-
assert f1.cancelled()
229-
assert f2.done()
230-
assert recvd == [b"hi", b"there"]
209+
async with create_task_group() as tg:
210+
a, b = push_pull
211+
f1 = create_task(b.recv(), tg)
212+
f2 = create_task(b.recv_multipart(), tg)
213+
f1.cancel(raise_exception=False)
214+
assert f1.done()
215+
assert not f2.done()
216+
await a.send_multipart([b"hi", b"there"])
217+
recvd = await f2.wait()
218+
assert f1.cancelled()
219+
assert f2.done()
220+
assert recvd == [b"hi", b"there"]
231221

232222

233223
async def test_poll(push_pull):
234-
a, b = push_pull
235-
f = b.poll(timeout=0)
236-
await asyncio.sleep(0)
237-
assert f.result() == 0
224+
async with create_task_group() as tg:
225+
a, b = push_pull
226+
f = create_task(b.poll(timeout=0), tg)
227+
await sleep(0.01)
228+
assert f.result() == 0
238229

239-
f = b.poll(timeout=1)
240-
assert not f.done()
241-
evt = await f
230+
f = create_task(b.poll(timeout=1), tg)
231+
assert not f.done()
232+
evt = await f.wait()
242233

243-
assert evt == 0
234+
assert evt == 0
244235

245-
f = b.poll(timeout=1000)
246-
assert not f.done()
247-
await a.send_multipart([b"hi", b"there"])
248-
evt = await f
249-
assert evt == zmq.POLLIN
250-
recvd = await b.recv_multipart()
251-
assert recvd == [b"hi", b"there"]
236+
f = create_task(b.poll(timeout=1000), tg)
237+
assert not f.done()
238+
await a.send_multipart([b"hi", b"there"])
239+
evt = await f.wait()
240+
assert evt == zmq.POLLIN
241+
recvd = await b.recv_multipart()
242+
assert recvd == [b"hi", b"there"]
252243

253244

254245
async def test_poll_base_socket(sockets):
255-
ctx = zmq.Context()
256-
url = "inproc://test"
257-
a = ctx.socket(zmq.PUSH)
258-
b = ctx.socket(zmq.PULL)
259-
sockets.extend([a, b])
260-
a.bind(url)
261-
b.connect(url)
262-
263-
poller = zaio.Poller()
264-
poller.register(b, zmq.POLLIN)
265-
266-
f = poller.poll(timeout=1000)
267-
assert not f.done()
268-
a.send_multipart([b"hi", b"there"])
269-
evt = await f
270-
assert evt == [(b, zmq.POLLIN)]
271-
recvd = b.recv_multipart()
272-
assert recvd == [b"hi", b"there"]
246+
async with create_task_group() as tg:
247+
ctx = zmq.Context()
248+
url = "inproc://test"
249+
a = ctx.socket(zmq.PUSH)
250+
b = ctx.socket(zmq.PULL)
251+
sockets.extend([a, b])
252+
a.bind(url)
253+
b.connect(url)
254+
255+
poller = zaio.Poller()
256+
poller.register(b, zmq.POLLIN)
257+
258+
f = create_task(poller.poll(timeout=1000), tg)
259+
assert not f.done()
260+
a.send_multipart([b"hi", b"there"])
261+
evt = await f.wait()
262+
assert evt == [(b, zmq.POLLIN)]
263+
recvd = b.recv_multipart()
264+
assert recvd == [b"hi", b"there"]
273265

274266

275267
async def test_poll_on_closed_socket(push_pull):
276-
a, b = push_pull
268+
with pytest.raises(BaseExceptionGroup) as excinfo:
269+
async with create_task_group() as tg:
270+
a, b = push_pull
277271

278-
f = b.poll(timeout=1)
279-
b.close()
272+
f = create_task(b.poll(timeout=1), tg)
273+
b.close()
280274

281-
# The test might stall if we try to await f directly so instead just make a few
282-
# passes through the event loop to schedule and execute all callbacks
283-
for _ in range(5):
284-
await asyncio.sleep(0)
285-
if f.cancelled():
286-
break
287-
assert f.cancelled()
275+
# The test might stall if we try to await f directly so instead just make a few
276+
# passes through the event loop to schedule and execute all callbacks
277+
for _ in range(5):
278+
await sleep(0)
279+
if f.cancelled():
280+
break
281+
assert f.done()
282+
assert excinfo.group_contains(zmq.error.ZMQError)
288283

289284

290285
@pytest.mark.skipif(
@@ -344,16 +339,17 @@ def test_shadow():
344339

345340

346341
async def test_poll_leak():
347-
ctx = zmq.asyncio.Context()
348-
with ctx, ctx.socket(zmq.PULL) as s:
349-
assert len(s._recv_futures) == 0
350-
for i in range(10):
351-
f = asyncio.ensure_future(s.poll(timeout=1000, flags=zmq.PollEvent.POLLIN))
352-
f.cancel()
353-
await asyncio.sleep(0)
354-
# one more sleep allows further chained cleanup
355-
await asyncio.sleep(0.1)
356-
assert len(s._recv_futures) == 0
342+
async with create_task_group() as tg:
343+
ctx = zmq.asyncio.Context()
344+
with ctx, ctx.socket(zmq.PULL) as s:
345+
assert len(s._recv_futures) == 0
346+
for i in range(10):
347+
f = create_task(s.poll(timeout=1000, flags=zmq.PollEvent.POLLIN), tg)
348+
f.cancel(raise_exception=False)
349+
await sleep(0)
350+
# one more sleep allows further chained cleanup
351+
await sleep(0.1)
352+
assert len(s._recv_futures) == 0
357353

358354

359355
class ProcessForTeardownTest(Process):

tests/test_ioloop.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@
1212
_tornado = True
1313

1414

15-
def setup():
16-
if not _tornado:
17-
pytest.skip("requires tornado")
15+
if not _tornado:
16+
pytest.skip("requires tornado", allow_module_level=True)
1817

1918

2019
def test_ioloop():

0 commit comments

Comments
 (0)