-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwebsocket_channels.py
178 lines (151 loc) · 6.43 KB
/
websocket_channels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
"""
Based on Flask-Sockets (https://github.com/kennethreitz/flask-sockets) and
https://devcenter.heroku.com/articles/python-websockets
"""
import functools
import logging
import redis
import gevent
import geventwebsocket.gunicorn.workers
logger = logging.getLogger(__name__)
# The level is hard-coded here: this module's INFO and DEBUG records are
# dropped unless some other code lowers the level again.
logger.setLevel(logging.WARNING)
class Worker(geventwebsocket.gunicorn.workers.GeventWebSocketWorker):
    """Gunicorn worker class for serving this module's websockets.

    It currently adds nothing to geventwebsocket's ``GeventWebSocketWorker``
    (see the TODO below).
    """
    # TODO: hook into the WebSocket connection handshake so that undesired
    # TODO: connections can be rejected before they are accepted; for an example
    # TODO: see https://github.com/abourget/gevent-socketio/blob/master/socketio/sgunicorn.py
def async(func):
"""Decorator to make a function to be executed asynchronously using a Greenlet.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
return gevent.spawn(func, *args, **kwargs)
return wrapper
class ChannelSockets(object):
    """A node in a tree of channels, holding the websockets registered on it.

    Sub-channels are created lazily on first access through ``node[name]``.
    A sub-channel's ``name`` is its parent's name and its own name joined with
    ``'/'``, so the children of a root ``ChannelSockets('')`` are named
    ``'/<name>'``.
    """

    def __init__(self, name):
        self.name = name
        # Websockets registered directly on this channel.
        self.websockets = set()
        # Sub-channel name -> ChannelSockets node.
        self._subchannels = {}

    def __getitem__(self, name):
        """Return the sub-channel called *name*, creating it if it does not exist yet."""
        channel = self._subchannels.get(name)
        if channel is None:
            channel = self.__class__(self.name + '/' + name)
            self._subchannels[name] = channel
        return channel

    def __iter__(self):
        """Iterate over the direct (non-recursive) sub-channels of this channel.

        The iteration runs over a snapshot of the sub-channels. Callers may
        iterate while sending on websockets, which can switch to another
        greenlet that adds a sub-channel. Iterating the live dict would then
        raise ``RuntimeError: dictionary changed size during iteration``.
        ``values()`` also works on both Python 2 and 3, unlike
        ``itervalues()``.
        """
        return iter(list(self._subchannels.values()))
class WebSocketChannelMiddleware(object):
"""WSGI middleware around a WSGI application which expects `wsgi.websocket` request
environment value provided by a Gunicorn worker and handles that websocket.
"""
REDIS_CHANNEL_PREFIX = 'websocket:'
def __init__(self, wsgi_app, redis_url):
self.wsgi_app = wsgi_app
self.redis_client = redis.from_url(redis_url)
self.pubsub = self.redis_client.pubsub(ignore_subscribe_messages=True)
self.channel_sockets = ChannelSockets('')
self._listen()
def __call__(self, environ, start_response):
path = environ['PATH_INFO']
if path.startswith('/ws/'):
channel = path[4:].rstrip('/')
websocket = environ['wsgi.websocket']
self._handle_websocket_connection(websocket, channel)
else: # call the wrapped app
return self.wsgi_app(environ, start_response)
def _handle_websocket_connection(self, websocket, channel):
"""Receive messages from a websocket.
"""
self._register_websocket(websocket, channel)
while True:
gevent.sleep(0.05) # switch to send messages
try:
message = websocket.receive()
except geventwebsocket.WebSocketError:
break
if message:
self.on_message(message, channel)
def _register_websocket(self, websocket, channel):
"""Register a websocket so it can be sent published messages.
"""
sockets = self.channel_sockets
for channel in channel.split('/'):
sockets = sockets[channel]
sockets.websockets.add(websocket)
def on_message(self, message, channel):
"""Hook called when a new message from a client via websocket arrives.
The default implementation publishes the message. You can subclass this to apply custom
logic (e.g filtering).
Args:
message (str): message to publish
channel (str): on which channel
"""
self.publish_message(message, channel)
@async
def publish_message(self, message, channel):
"""Asynchronously PUBLISH a message to the given Redis channel. SUBSCRIBEd Redis clients
will be notified about it.
Args:
message (str): message to publish
channel (str): on which channel
"""
logger.info(u'Pusblishing message on channel `%s`: %s', channel, message)
self.redis_client.publish(self.REDIS_CHANNEL_PREFIX + channel, message)
@async
def _listen(self):
"""Listen in a thread for new messages in Redis, and send them to registered web-sockets.
See: https://github.com/andymccurdy/redis-py#publish--subscribe
"""
self.pubsub.psubscribe(self.REDIS_CHANNEL_PREFIX + '*') # listen to all channels
channel_prefix_len = len(self.REDIS_CHANNEL_PREFIX)
while True:
message = self.pubsub.get_message() if self.pubsub.subscribed else None
if not message:
gevent.sleep(0.05) # be nice to the system
continue
channel = message['channel'][channel_prefix_len:]
logger.debug(u'Received a message on channel `%s`: %s', channel, message)
self._send_message(channel, message['data'])
@async
def _send_message(self, channel, message):
"""Asynchronously send a message to websockets handled by this worker on the given channel.
"""
only_subchannels = channel.endswith('/')
if only_subchannels:
logger.info(u'Sending message to clients on sub-channels of `%s`: %s',
channel, message)
else:
logger.info(u'Sending message to clients on channel `%s`: %s', channel, message)
channel_sockets = self.channel_sockets
for channel in channel.split('/'):
if channel:
channel_sockets = channel_sockets[channel]
if only_subchannels:
self._send_message_subchannels(message, channel_sockets)
else:
self._send_message_channel(message, channel_sockets)
def _send_message_channel(self, message, channel_sockets):
"""Send the given meesage only to websockets of the given channel.
"""
websockets = channel_sockets.websockets
for websocket in tuple(websockets): # changes during iteration
# import random
# if not random.randint(0, 10):
# websocket.close()
# continue
try:
websocket.send(message)
except geventwebsocket.WebSocketError:
# discard invalid connection
websockets.remove(websocket)
def _send_message_subchannels(self, message, channel_sockets):
"""Send the given meesage to weboskets only of subchannels of the given channel.
"""
for channel_sockets in channel_sockets:
self._send_message_channel(message, channel_sockets)
self._send_message_subchannels(message, channel_sockets)