1
+ from __future__ import annotations
2
+
1
3
import asyncio
2
4
import fnmatch
3
5
import random
4
6
import re
5
7
import string
6
8
import time
7
9
from copy import deepcopy
10
+ from typing import Dict , Iterable , List , Optional , Tuple
8
11
9
12
from django .conf import settings
10
13
from django .core .signals import setting_changed
@@ -20,6 +23,8 @@ class ChannelLayerManager:
20
23
Takes a settings dictionary of backends and initialises them on request.
21
24
"""
22
25
26
+ backends : Dict [str , BaseChannelLayer ]
27
+
23
28
def __init__ (self ):
24
29
self .backends = {}
25
30
setting_changed .connect (self ._reset_backends )
@@ -36,14 +41,14 @@ def configs(self):
36
41
# Lazy load settings so we can be imported
37
42
return getattr (settings , "CHANNEL_LAYERS" , {})
38
43
39
- def make_backend (self , name ):
44
+ def make_backend (self , name ) -> BaseChannelLayer :
40
45
"""
41
46
Instantiate channel layer.
42
47
"""
43
48
config = self .configs [name ].get ("CONFIG" , {})
44
49
return self ._make_backend (name , config )
45
50
46
- def make_test_backend (self , name ):
51
+ def make_test_backend (self , name ) -> BaseChannelLayer :
47
52
"""
48
53
Instantiate channel layer using its test config.
49
54
"""
@@ -53,7 +58,7 @@ def make_test_backend(self, name):
53
58
raise InvalidChannelLayerError ("No TEST_CONFIG specified for %s" % name )
54
59
return self ._make_backend (name , config )
55
60
56
- def _make_backend (self , name , config ):
61
+ def _make_backend (self , name , config ) -> BaseChannelLayer :
57
62
# Check for old format config
58
63
if "ROUTING" in self .configs [name ]:
59
64
raise InvalidChannelLayerError (
@@ -81,7 +86,7 @@ def __getitem__(self, key):
81
86
def __contains__ (self , key ):
82
87
return key in self .configs
83
88
84
- def set (self , key , layer ):
89
+ def set (self , key : str , layer : BaseChannelLayer ):
85
90
"""
86
91
Sets an alias to point to a new ChannelLayerWrapper instance, and
87
92
returns the old one that it replaced. Useful for swapping out the
@@ -99,13 +104,21 @@ class BaseChannelLayer:
99
104
"""
100
105
101
106
MAX_NAME_LENGTH = 100
107
+ extensions : Iterable [str ] = ()
102
108
103
- def __init__ (self , expiry = 60 , capacity = 100 , channel_capacity = None ):
109
+ def __init__ (
110
+ self ,
111
+ expiry = 60 ,
112
+ capacity : Optional [int ] = 100 ,
113
+ channel_capacity : Optional [int ] = None ,
114
+ ):
104
115
self .expiry = expiry
105
116
self .capacity = capacity
106
117
self .channel_capacity = channel_capacity or {}
107
118
108
- def compile_capacities (self , channel_capacity ):
119
+ def compile_capacities (
120
+ self , channel_capacity
121
+ ) -> List [Tuple [re .Pattern , Optional [int ]]]:
109
122
"""
110
123
Takes an input channel_capacity dict and returns the compiled list
111
124
of regexes that get_capacity will look for as self.channel_capacity
@@ -120,7 +133,7 @@ def compile_capacities(self, channel_capacity):
120
133
result .append ((re .compile (fnmatch .translate (pattern )), value ))
121
134
return result
122
135
123
- def get_capacity (self , channel ) :
136
+ def get_capacity (self , channel : str ) -> Optional [ int ] :
124
137
"""
125
138
Gets the correct capacity for the given channel; either the default,
126
139
or a matching result from channel_capacity. Returns the first matching
@@ -132,7 +145,7 @@ def get_capacity(self, channel):
132
145
return capacity
133
146
return self .capacity
134
147
135
- def match_type_and_length (self , name ):
148
+ def match_type_and_length (self , name ) -> bool :
136
149
if isinstance (name , str ) and (len (name ) < self .MAX_NAME_LENGTH ):
137
150
return True
138
151
return False
@@ -148,7 +161,7 @@ def match_type_and_length(self, name):
148
161
+ "not {}"
149
162
)
150
163
151
- def valid_channel_name (self , name , receive = False ):
164
+ def valid_channel_name (self , name : str , receive = False ) -> bool :
152
165
if self .match_type_and_length (name ):
153
166
if bool (self .channel_name_regex .match (name )):
154
167
# Check cases for special channels
@@ -159,13 +172,13 @@ def valid_channel_name(self, name, receive=False):
159
172
return True
160
173
raise TypeError (self .invalid_name_error .format ("Channel" , name ))
161
174
162
- def valid_group_name (self , name ) :
175
+ def valid_group_name (self , name : str ) -> bool :
163
176
if self .match_type_and_length (name ):
164
177
if bool (self .group_name_regex .match (name )):
165
178
return True
166
179
raise TypeError (self .invalid_name_error .format ("Group" , name ))
167
180
168
- def valid_channel_names (self , names , receive = False ):
181
+ def valid_channel_names (self , names , receive = False ) -> bool :
169
182
_non_empty_list = True if names else False
170
183
_names_type = isinstance (names , list )
171
184
assert _non_empty_list and _names_type , "names must be a non-empty list"
@@ -175,7 +188,7 @@ def valid_channel_names(self, names, receive=False):
175
188
)
176
189
return True
177
190
178
- def non_local_name (self , name ) :
191
+ def non_local_name (self , name : str ) -> str :
179
192
"""
180
193
Given a channel name, returns the "non-local" part. If the channel name
181
194
is a process-specific channel (contains !) this means the part up to
@@ -186,6 +199,49 @@ def non_local_name(self, name):
186
199
else :
187
200
return name
188
201
202
+ async def send (self , channel : str , message : dict ):
203
+ """
204
+ Send a message onto a (general or specific) channel.
205
+ """
206
+ raise NotImplementedError ()
207
+
208
+ async def receive (self , channel : str ) -> dict :
209
+ """
210
+ Receive the first message that arrives on the channel.
211
+ If more than one coroutine waits on the same channel, a random one
212
+ of the waiting coroutines will get the result.
213
+ """
214
+ raise NotImplementedError ()
215
+
216
+ async def new_channel (self , prefix : str = "specific." ) -> str :
217
+ """
218
+ Returns a new channel name that can be used by something in our
219
+ process as a specific channel.
220
+ """
221
+ raise NotImplementedError ()
222
+
223
+ # Flush extension
224
+
225
+ async def flush (self ):
226
+ raise NotImplementedError ()
227
+
228
+ async def close (self ):
229
+ raise NotImplementedError ()
230
+
231
+ # Groups extension
232
+
233
+ async def group_add (self , group : str , channel : str ):
234
+ """
235
+ Adds the channel name to a group.
236
+ """
237
+ raise NotImplementedError ()
238
+
239
+ async def group_discard (self , group : str , channel : str ):
240
+ raise NotImplementedError ()
241
+
242
+ async def group_send (self , group : str , message : dict ):
243
+ raise NotImplementedError ()
244
+
189
245
190
246
class InMemoryChannelLayer (BaseChannelLayer ):
191
247
"""
@@ -198,13 +254,13 @@ def __init__(
198
254
group_expiry = 86400 ,
199
255
capacity = 100 ,
200
256
channel_capacity = None ,
201
- ** kwargs
257
+ ** kwargs ,
202
258
):
203
259
super ().__init__ (
204
260
expiry = expiry ,
205
261
capacity = capacity ,
206
262
channel_capacity = channel_capacity ,
207
- ** kwargs
263
+ ** kwargs ,
208
264
)
209
265
self .channels = {}
210
266
self .groups = {}
@@ -215,9 +271,6 @@ def __init__(
215
271
extensions = ["groups" , "flush" ]
216
272
217
273
async def send (self , channel , message ):
218
- """
219
- Send a message onto a (general or specific) channel.
220
- """
221
274
# Typecheck
222
275
assert isinstance (message , dict ), "message is not a dict"
223
276
assert self .valid_channel_name (channel ), "Channel name not valid"
@@ -234,11 +287,6 @@ async def send(self, channel, message):
234
287
await queue .put ((time .time () + self .expiry , deepcopy (message )))
235
288
236
289
async def receive (self , channel ):
237
- """
238
- Receive the first message that arrives on the channel.
239
- If more than one coroutine waits on the same channel, a random one
240
- of the waiting coroutines will get the result.
241
- """
242
290
assert self .valid_channel_name (channel )
243
291
self ._clean_expired ()
244
292
@@ -254,10 +302,6 @@ async def receive(self, channel):
254
302
return message
255
303
256
304
async def new_channel (self , prefix = "specific." ):
257
- """
258
- Returns a new channel name that can be used by something in our
259
- process as a specific channel.
260
- """
261
305
return "%s.inmemory!%s" % (
262
306
prefix ,
263
307
"" .join (random .choice (string .ascii_letters ) for i in range (12 )),
@@ -314,9 +358,6 @@ def _remove_from_groups(self, channel):
314
358
# Groups extension
315
359
316
360
async def group_add (self , group , channel ):
317
- """
318
- Adds the channel name to a group.
319
- """
320
361
# Check the inputs
321
362
assert self .valid_group_name (group ), "Group name not valid"
322
363
assert self .valid_channel_name (channel ), "Channel name not valid"
@@ -349,7 +390,7 @@ async def group_send(self, group, message):
349
390
pass
350
391
351
392
352
- def get_channel_layer (alias = DEFAULT_CHANNEL_LAYER ):
393
+ def get_channel_layer (alias = DEFAULT_CHANNEL_LAYER ) -> Optional [ BaseChannelLayer ] :
353
394
"""
354
395
Returns a channel layer by alias, or None if it is not configured.
355
396
"""
0 commit comments