Skip to content

Commit 858ce88

Browse files
committed
feature: improve layer typings, add method stubs
1 parent 0933260 commit 858ce88

File tree

1 file changed

+71
-30
lines changed

1 file changed

+71
-30
lines changed

channels/layers.py

+71-30
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
from __future__ import annotations
2+
13
import asyncio
24
import fnmatch
35
import random
46
import re
57
import string
68
import time
79
from copy import deepcopy
10+
from typing import Dict, Iterable, List, Optional, Tuple
811

912
from django.conf import settings
1013
from django.core.signals import setting_changed
@@ -20,6 +23,8 @@ class ChannelLayerManager:
2023
Takes a settings dictionary of backends and initialises them on request.
2124
"""
2225

26+
backends: Dict[str, BaseChannelLayer]
27+
2328
def __init__(self):
2429
self.backends = {}
2530
setting_changed.connect(self._reset_backends)
@@ -36,14 +41,14 @@ def configs(self):
3641
# Lazy load settings so we can be imported
3742
return getattr(settings, "CHANNEL_LAYERS", {})
3843

39-
def make_backend(self, name):
44+
def make_backend(self, name) -> BaseChannelLayer:
4045
"""
4146
Instantiate channel layer.
4247
"""
4348
config = self.configs[name].get("CONFIG", {})
4449
return self._make_backend(name, config)
4550

46-
def make_test_backend(self, name):
51+
def make_test_backend(self, name) -> BaseChannelLayer:
4752
"""
4853
Instantiate channel layer using its test config.
4954
"""
@@ -53,7 +58,7 @@ def make_test_backend(self, name):
5358
raise InvalidChannelLayerError("No TEST_CONFIG specified for %s" % name)
5459
return self._make_backend(name, config)
5560

56-
def _make_backend(self, name, config):
61+
def _make_backend(self, name, config) -> BaseChannelLayer:
5762
# Check for old format config
5863
if "ROUTING" in self.configs[name]:
5964
raise InvalidChannelLayerError(
@@ -81,7 +86,7 @@ def __getitem__(self, key):
8186
def __contains__(self, key):
8287
return key in self.configs
8388

84-
def set(self, key, layer):
89+
def set(self, key: str, layer: BaseChannelLayer):
8590
"""
8691
Sets an alias to point to a new ChannelLayerWrapper instance, and
8792
returns the old one that it replaced. Useful for swapping out the
@@ -99,13 +104,21 @@ class BaseChannelLayer:
99104
"""
100105

101106
MAX_NAME_LENGTH = 100
107+
extensions: Iterable[str] = ()
102108

103-
def __init__(self, expiry=60, capacity=100, channel_capacity=None):
109+
def __init__(
110+
self,
111+
expiry=60,
112+
capacity: Optional[int] = 100,
113+
channel_capacity: Optional[int] = None,
114+
):
104115
self.expiry = expiry
105116
self.capacity = capacity
106117
self.channel_capacity = channel_capacity or {}
107118

108-
def compile_capacities(self, channel_capacity):
119+
def compile_capacities(
120+
self, channel_capacity
121+
) -> List[Tuple[re.Pattern, Optional[int]]]:
109122
"""
110123
Takes an input channel_capacity dict and returns the compiled list
111124
of regexes that get_capacity will look for as self.channel_capacity
@@ -120,7 +133,7 @@ def compile_capacities(self, channel_capacity):
120133
result.append((re.compile(fnmatch.translate(pattern)), value))
121134
return result
122135

123-
def get_capacity(self, channel):
136+
def get_capacity(self, channel: str) -> Optional[int]:
124137
"""
125138
Gets the correct capacity for the given channel; either the default,
126139
or a matching result from channel_capacity. Returns the first matching
@@ -132,7 +145,7 @@ def get_capacity(self, channel):
132145
return capacity
133146
return self.capacity
134147

135-
def match_type_and_length(self, name):
148+
def match_type_and_length(self, name) -> bool:
136149
if isinstance(name, str) and (len(name) < self.MAX_NAME_LENGTH):
137150
return True
138151
return False
@@ -148,7 +161,7 @@ def match_type_and_length(self, name):
148161
+ "not {}"
149162
)
150163

151-
def valid_channel_name(self, name, receive=False):
164+
def valid_channel_name(self, name: str, receive=False) -> bool:
152165
if self.match_type_and_length(name):
153166
if bool(self.channel_name_regex.match(name)):
154167
# Check cases for special channels
@@ -159,13 +172,13 @@ def valid_channel_name(self, name, receive=False):
159172
return True
160173
raise TypeError(self.invalid_name_error.format("Channel", name))
161174

162-
def valid_group_name(self, name):
175+
def valid_group_name(self, name: str) -> bool:
163176
if self.match_type_and_length(name):
164177
if bool(self.group_name_regex.match(name)):
165178
return True
166179
raise TypeError(self.invalid_name_error.format("Group", name))
167180

168-
def valid_channel_names(self, names, receive=False):
181+
def valid_channel_names(self, names, receive=False) -> bool:
169182
_non_empty_list = True if names else False
170183
_names_type = isinstance(names, list)
171184
assert _non_empty_list and _names_type, "names must be a non-empty list"
@@ -175,7 +188,7 @@ def valid_channel_names(self, names, receive=False):
175188
)
176189
return True
177190

178-
def non_local_name(self, name):
191+
def non_local_name(self, name: str) -> str:
179192
"""
180193
Given a channel name, returns the "non-local" part. If the channel name
181194
is a process-specific channel (contains !) this means the part up to
@@ -186,6 +199,49 @@ def non_local_name(self, name):
186199
else:
187200
return name
188201

202+
async def send(self, channel: str, message: dict):
203+
"""
204+
Send a message onto a (general or specific) channel.
205+
"""
206+
raise NotImplementedError()
207+
208+
async def receive(self, channel: str) -> dict:
209+
"""
210+
Receive the first message that arrives on the channel.
211+
If more than one coroutine waits on the same channel, a random one
212+
of the waiting coroutines will get the result.
213+
"""
214+
raise NotImplementedError()
215+
216+
async def new_channel(self, prefix: str = "specific.") -> str:
217+
"""
218+
Returns a new channel name that can be used by something in our
219+
process as a specific channel.
220+
"""
221+
raise NotImplementedError()
222+
223+
# Flush extension
224+
225+
async def flush(self):
226+
raise NotImplementedError()
227+
228+
async def close(self):
229+
raise NotImplementedError()
230+
231+
# Groups extension
232+
233+
async def group_add(self, group: str, channel: str):
234+
"""
235+
Adds the channel name to a group.
236+
"""
237+
raise NotImplementedError()
238+
239+
async def group_discard(self, group: str, channel: str):
240+
raise NotImplementedError()
241+
242+
async def group_send(self, group: str, message: dict):
243+
raise NotImplementedError()
244+
189245

190246
class InMemoryChannelLayer(BaseChannelLayer):
191247
"""
@@ -198,13 +254,13 @@ def __init__(
198254
group_expiry=86400,
199255
capacity=100,
200256
channel_capacity=None,
201-
**kwargs
257+
**kwargs,
202258
):
203259
super().__init__(
204260
expiry=expiry,
205261
capacity=capacity,
206262
channel_capacity=channel_capacity,
207-
**kwargs
263+
**kwargs,
208264
)
209265
self.channels = {}
210266
self.groups = {}
@@ -215,9 +271,6 @@ def __init__(
215271
extensions = ["groups", "flush"]
216272

217273
async def send(self, channel, message):
218-
"""
219-
Send a message onto a (general or specific) channel.
220-
"""
221274
# Typecheck
222275
assert isinstance(message, dict), "message is not a dict"
223276
assert self.valid_channel_name(channel), "Channel name not valid"
@@ -234,11 +287,6 @@ async def send(self, channel, message):
234287
await queue.put((time.time() + self.expiry, deepcopy(message)))
235288

236289
async def receive(self, channel):
237-
"""
238-
Receive the first message that arrives on the channel.
239-
If more than one coroutine waits on the same channel, a random one
240-
of the waiting coroutines will get the result.
241-
"""
242290
assert self.valid_channel_name(channel)
243291
self._clean_expired()
244292

@@ -254,10 +302,6 @@ async def receive(self, channel):
254302
return message
255303

256304
async def new_channel(self, prefix="specific."):
257-
"""
258-
Returns a new channel name that can be used by something in our
259-
process as a specific channel.
260-
"""
261305
return "%s.inmemory!%s" % (
262306
prefix,
263307
"".join(random.choice(string.ascii_letters) for i in range(12)),
@@ -314,9 +358,6 @@ def _remove_from_groups(self, channel):
314358
# Groups extension
315359

316360
async def group_add(self, group, channel):
317-
"""
318-
Adds the channel name to a group.
319-
"""
320361
# Check the inputs
321362
assert self.valid_group_name(group), "Group name not valid"
322363
assert self.valid_channel_name(channel), "Channel name not valid"
@@ -349,7 +390,7 @@ async def group_send(self, group, message):
349390
pass
350391

351392

352-
def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER):
393+
def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER) -> Optional[BaseChannelLayer]:
353394
"""
354395
Returns a channel layer by alias, or None if it is not configured.
355396
"""

0 commit comments

Comments
 (0)