11
11
from multiprocessing import Process
12
12
13
13
import pytest
14
+ from anyio import create_task_group , move_on_after , sleep
15
+ from anyioutils import CancelledError , create_task
14
16
from pytest import mark
15
17
16
18
import zmq
17
19
import zmq .asyncio as zaio
18
20
21
+ if sys .version_info < (3 , 11 ):
22
+ from exceptiongroup import BaseExceptionGroup , ExceptionGroup
23
+
24
+
25
+ pytestmark = pytest .mark .anyio
26
+
19
27
20
28
@pytest .fixture
21
29
def Context (event_loop ):
@@ -46,23 +54,17 @@ def test_instance_subclass_second(context):
46
54
async def test_recv_multipart (context , create_bound_pair ):
47
55
a , b = create_bound_pair (zmq .PUSH , zmq .PULL )
48
56
f = b .recv_multipart ()
49
- assert not f .done ()
50
57
await a .send (b"hi" )
51
- recvd = await f
52
- assert recvd == [b"hi" ]
58
+ assert await f == [b"hi" ]
53
59
54
60
55
61
async def test_recv (create_bound_pair ):
56
62
a , b = create_bound_pair (zmq .PUSH , zmq .PULL )
57
63
f1 = b .recv ()
58
64
f2 = b .recv ()
59
- assert not f1 .done ()
60
- assert not f2 .done ()
61
65
await a .send_multipart ([b"hi" , b"there" ])
62
- recvd = await f2
63
- assert f1 .done ()
64
- assert f1 .result () == b"hi"
65
- assert recvd == b"there"
66
+ assert await f1 == b"hi"
67
+ assert await f2 == b"there"
66
68
67
69
68
70
@mark .skipif (not hasattr (zmq , "RCVTIMEO" ), reason = "requires RCVTIMEO" )
@@ -72,82 +74,70 @@ async def test_recv_timeout(push_pull):
72
74
f1 = b .recv ()
73
75
b .rcvtimeo = 1000
74
76
f2 = b .recv_multipart ()
75
- with pytest .raises (zmq . Again ) :
77
+ with pytest .raises (ExceptionGroup ) as excinfo :
76
78
await f1
79
+ assert excinfo .group_contains (zmq .Again )
77
80
await a .send_multipart ([b"hi" , b"there" ])
78
81
recvd = await f2
79
- assert f2 .done ()
80
82
assert recvd == [b"hi" , b"there" ]
81
83
82
84
83
85
@mark .skipif (not hasattr (zmq , "SNDTIMEO" ), reason = "requires SNDTIMEO" )
84
86
async def test_send_timeout (socket ):
85
87
s = socket (zmq .PUSH )
86
88
s .sndtimeo = 100
87
- with pytest .raises (zmq . Again ) :
89
+ with pytest .raises (ExceptionGroup ) as excinfo :
88
90
await s .send (b"not going anywhere" )
91
+ assert excinfo .group_contains (zmq .Again )
89
92
90
93
91
94
async def test_recv_string (push_pull ):
92
95
a , b = push_pull
93
96
f = b .recv_string ()
94
- assert not f .done ()
95
97
msg = "πøøπ"
96
98
await a .send_string (msg )
97
99
recvd = await f
98
- assert f .done ()
99
- assert f .result () == msg
100
100
assert recvd == msg
101
101
102
102
103
103
async def test_recv_json (push_pull ):
104
104
a , b = push_pull
105
105
f = b .recv_json ()
106
- assert not f .done ()
107
106
obj = dict (a = 5 )
108
107
await a .send_json (obj )
109
108
recvd = await f
110
- assert f .done ()
111
- assert f .result () == obj
112
109
assert recvd == obj
113
110
114
111
115
112
async def test_recv_json_cancelled (push_pull ):
116
- a , b = push_pull
117
- f = b .recv_json ()
118
- assert not f .done ()
119
- f .cancel ()
120
- # cycle eventloop to allow cancel events to fire
121
- await asyncio .sleep (0 )
122
- obj = dict (a = 5 )
123
- await a .send_json (obj )
124
- # CancelledError change in 3.8 https://bugs.python.org/issue32528
125
- if sys .version_info < (3 , 8 ):
126
- with pytest .raises (CancelledError ):
113
+ async with create_task_group () as tg :
114
+ a , b = push_pull
115
+ f = create_task (b .recv_json (), tg )
116
+ f .cancel (raise_exception = False )
117
+ # cycle eventloop to allow cancel events to fire
118
+ await sleep (0 )
119
+ obj = dict (a = 5 )
120
+ await a .send_json (obj )
121
+ recvd = await f .wait ()
122
+ assert f .cancelled ()
123
+ assert f .done ()
124
+ # give it a chance to incorrectly consume the event
125
+ events = await b .poll (timeout = 5 )
126
+ assert events
127
+ await sleep (0 )
128
+ # make sure cancelled recv didn't eat up event
129
+ f = b .recv_json ()
130
+ with move_on_after (5 ):
127
131
recvd = await f
128
- else :
129
- with pytest .raises (asyncio .exceptions .CancelledError ):
130
- recvd = await f
131
- assert f .done ()
132
- # give it a chance to incorrectly consume the event
133
- events = await b .poll (timeout = 5 )
134
- assert events
135
- await asyncio .sleep (0 )
136
- # make sure cancelled recv didn't eat up event
137
- f = b .recv_json ()
138
- recvd = await asyncio .wait_for (f , timeout = 5 )
139
- assert recvd == obj
132
+ assert recvd == obj
140
133
141
134
142
135
async def test_recv_pyobj (push_pull ):
143
136
a , b = push_pull
144
137
f = b .recv_pyobj ()
145
- assert not f .done ()
146
138
obj = dict (a = 5 )
147
139
await a .send_pyobj (obj )
148
140
recvd = await f
149
- assert f .done ()
150
- assert f .result () == obj
151
141
assert recvd == obj
152
142
153
143
@@ -206,85 +196,90 @@ async def test_custom_serialize_error(dealer_router):
206
196
async def test_recv_dontwait (push_pull ):
207
197
push , pull = push_pull
208
198
f = pull .recv (zmq .DONTWAIT )
209
- with pytest .raises (zmq . Again ) :
199
+ with pytest .raises (BaseExceptionGroup ) as excinfo :
210
200
await f
201
+ assert excinfo .group_contains (zmq .Again )
211
202
await push .send (b"ping" )
212
203
await pull .poll () # ensure message will be waiting
213
- f = pull .recv (zmq .DONTWAIT )
214
- assert f .done ()
215
- msg = await f
204
+ msg = await pull .recv (zmq .DONTWAIT )
216
205
assert msg == b"ping"
217
206
218
207
219
208
async def test_recv_cancel (push_pull ):
220
- a , b = push_pull
221
- f1 = b .recv ()
222
- f2 = b .recv_multipart ()
223
- assert f1 .cancel ()
224
- assert f1 .done ()
225
- assert not f2 .done ()
226
- await a .send_multipart ([b"hi" , b"there" ])
227
- recvd = await f2
228
- assert f1 .cancelled ()
229
- assert f2 .done ()
230
- assert recvd == [b"hi" , b"there" ]
209
+ async with create_task_group () as tg :
210
+ a , b = push_pull
211
+ f1 = create_task (b .recv (), tg )
212
+ f2 = create_task (b .recv_multipart (), tg )
213
+ f1 .cancel (raise_exception = False )
214
+ assert f1 .done ()
215
+ assert not f2 .done ()
216
+ await a .send_multipart ([b"hi" , b"there" ])
217
+ recvd = await f2 .wait ()
218
+ assert f1 .cancelled ()
219
+ assert f2 .done ()
220
+ assert recvd == [b"hi" , b"there" ]
231
221
232
222
233
223
async def test_poll (push_pull ):
234
- a , b = push_pull
235
- f = b .poll (timeout = 0 )
236
- await asyncio .sleep (0 )
237
- assert f .result () == 0
224
+ async with create_task_group () as tg :
225
+ a , b = push_pull
226
+ f = create_task (b .poll (timeout = 0 ), tg )
227
+ await sleep (0.01 )
228
+ assert f .result () == 0
238
229
239
- f = b .poll (timeout = 1 )
240
- assert not f .done ()
241
- evt = await f
230
+ f = create_task ( b .poll (timeout = 1 ), tg )
231
+ assert not f .done ()
232
+ evt = await f . wait ()
242
233
243
- assert evt == 0
234
+ assert evt == 0
244
235
245
- f = b .poll (timeout = 1000 )
246
- assert not f .done ()
247
- await a .send_multipart ([b"hi" , b"there" ])
248
- evt = await f
249
- assert evt == zmq .POLLIN
250
- recvd = await b .recv_multipart ()
251
- assert recvd == [b"hi" , b"there" ]
236
+ f = create_task ( b .poll (timeout = 1000 ), tg )
237
+ assert not f .done ()
238
+ await a .send_multipart ([b"hi" , b"there" ])
239
+ evt = await f . wait ()
240
+ assert evt == zmq .POLLIN
241
+ recvd = await b .recv_multipart ()
242
+ assert recvd == [b"hi" , b"there" ]
252
243
253
244
254
245
async def test_poll_base_socket (sockets ):
255
- ctx = zmq .Context ()
256
- url = "inproc://test"
257
- a = ctx .socket (zmq .PUSH )
258
- b = ctx .socket (zmq .PULL )
259
- sockets .extend ([a , b ])
260
- a .bind (url )
261
- b .connect (url )
262
-
263
- poller = zaio .Poller ()
264
- poller .register (b , zmq .POLLIN )
265
-
266
- f = poller .poll (timeout = 1000 )
267
- assert not f .done ()
268
- a .send_multipart ([b"hi" , b"there" ])
269
- evt = await f
270
- assert evt == [(b , zmq .POLLIN )]
271
- recvd = b .recv_multipart ()
272
- assert recvd == [b"hi" , b"there" ]
246
+ async with create_task_group () as tg :
247
+ ctx = zmq .Context ()
248
+ url = "inproc://test"
249
+ a = ctx .socket (zmq .PUSH )
250
+ b = ctx .socket (zmq .PULL )
251
+ sockets .extend ([a , b ])
252
+ a .bind (url )
253
+ b .connect (url )
254
+
255
+ poller = zaio .Poller ()
256
+ poller .register (b , zmq .POLLIN )
257
+
258
+ f = create_task (poller .poll (timeout = 1000 ), tg )
259
+ assert not f .done ()
260
+ a .send_multipart ([b"hi" , b"there" ])
261
+ evt = await f .wait ()
262
+ assert evt == [(b , zmq .POLLIN )]
263
+ recvd = b .recv_multipart ()
264
+ assert recvd == [b"hi" , b"there" ]
273
265
274
266
275
267
async def test_poll_on_closed_socket (push_pull ):
276
- a , b = push_pull
268
+ with pytest .raises (BaseExceptionGroup ) as excinfo :
269
+ async with create_task_group () as tg :
270
+ a , b = push_pull
277
271
278
- f = b .poll (timeout = 1 )
279
- b .close ()
272
+ f = create_task ( b .poll (timeout = 1 ), tg )
273
+ b .close ()
280
274
281
- # The test might stall if we try to await f directly so instead just make a few
282
- # passes through the event loop to schedule and execute all callbacks
283
- for _ in range (5 ):
284
- await asyncio .sleep (0 )
285
- if f .cancelled ():
286
- break
287
- assert f .cancelled ()
275
+ # The test might stall if we try to await f directly so instead just make a few
276
+ # passes through the event loop to schedule and execute all callbacks
277
+ for _ in range (5 ):
278
+ await sleep (0 )
279
+ if f .cancelled ():
280
+ break
281
+ assert f .done ()
282
+ assert excinfo .group_contains (zmq .error .ZMQError )
288
283
289
284
290
285
@pytest .mark .skipif (
@@ -344,16 +339,17 @@ def test_shadow():
344
339
345
340
346
341
async def test_poll_leak ():
347
- ctx = zmq .asyncio .Context ()
348
- with ctx , ctx .socket (zmq .PULL ) as s :
349
- assert len (s ._recv_futures ) == 0
350
- for i in range (10 ):
351
- f = asyncio .ensure_future (s .poll (timeout = 1000 , flags = zmq .PollEvent .POLLIN ))
352
- f .cancel ()
353
- await asyncio .sleep (0 )
354
- # one more sleep allows further chained cleanup
355
- await asyncio .sleep (0.1 )
356
- assert len (s ._recv_futures ) == 0
342
+ async with create_task_group () as tg :
343
+ ctx = zmq .asyncio .Context ()
344
+ with ctx , ctx .socket (zmq .PULL ) as s :
345
+ assert len (s ._recv_futures ) == 0
346
+ for i in range (10 ):
347
+ f = create_task (s .poll (timeout = 1000 , flags = zmq .PollEvent .POLLIN ), tg )
348
+ f .cancel (raise_exception = False )
349
+ await sleep (0 )
350
+ # one more sleep allows further chained cleanup
351
+ await sleep (0.1 )
352
+ assert len (s ._recv_futures ) == 0
357
353
358
354
359
355
class ProcessForTeardownTest (Process ):
0 commit comments