Skip to content

Commit f3f976d

Browse files
fix: a failing websocket.send would drop the connection silently
Instead of assuming all exceptions are due to a closed connection, we only ignored the exception when it is due to a closed connection. This caused issues in #841 where we called from the same thread as the websocket's portal, which caused the connection to be ignored without any error message.
1 parent a980aa5 commit f3f976d

File tree

3 files changed

+51
-10
lines changed

3 files changed

+51
-10
lines changed

solara/server/flask.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,17 @@ def close(self):
6565

6666
def send_text(self, data: str) -> None:
6767
with self.lock:
68-
self.ws.send(data)
68+
try:
69+
self.ws.send(data)
70+
except simple_websocket.ws.ConnectionClosed:
71+
raise websocket.WebSocketDisconnect()
6972

7073
def send_bytes(self, data: bytes) -> None:
7174
with self.lock:
72-
self.ws.send(data)
75+
try:
76+
self.ws.send(data)
77+
except simple_websocket.ws.ConnectionClosed:
78+
raise websocket.WebSocketDisconnect()
7379

7480
async def receive(self):
7581
from anyio import to_thread

solara/server/kernel.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,20 @@ def send_websockets(websockets: Set[websocket.WebsocketWrapper], binary_msg):
215215
for ws in list(websockets):
216216
try:
217217
ws.send(binary_msg)
218-
except: # noqa
219-
# in case of any issue, we simply remove it from the list
218+
except websocket.WebSocketDisconnect:
219+
# ignore the exception, we tried to send while websocket closed
220+
# just remove it from the websocket set
221+
try:
222+
# websocket can be modified by another thread
223+
websockets.remove(ws)
224+
except KeyError:
225+
pass # already removed
226+
except Exception as e: # noqa
227+
logger.exception("Error sending message: %s, closing websocket", e)
228+
try:
229+
ws.close()
230+
except Exception as e: # noqa
231+
logger.exception("Error closing websocket: %s", e)
220232
try:
221233
# websocket can be modified by another thread
222234
websockets.remove(ws)

solara/server/starlette.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import starlette.websockets
1616
import uvicorn.server
1717
import websockets.legacy.http
18+
import websockets.exceptions
1819

1920
from solara.server.utils import path_is_child_of
2021

@@ -121,9 +122,31 @@ async def process_messages_task(self):
121122
while len(self.to_send) > 0:
122123
first = self.to_send.pop(0)
123124
if isinstance(first, bytes):
124-
await self.ws.send_bytes(first)
125+
await self._send_bytes_exc(first)
125126
else:
126-
await self.ws.send_text(first)
127+
await self._send_text_exc(first)
128+
129+
async def _send_bytes_exc(self, data: bytes):
130+
# make sures we catch the starlette/websockets specific exception
131+
# and re-raise it as a websocket.WebSocketDisconnect
132+
try:
133+
await self.ws.send_bytes(data)
134+
except websockets.exceptions.ConnectionClosed as e:
135+
raise websocket.WebSocketDisconnect() from e
136+
except RuntimeError as e:
137+
# starlette throws a RuntimeError once you call send after the connection is closed
138+
raise websocket.WebSocketDisconnect() from e
139+
140+
async def _send_text_exc(self, data: str):
141+
# make sures we catch the starlette/websockets specific exception
142+
# and re-raise it as a websocket.WebSocketDisconnect
143+
try:
144+
await self.ws.send_text(data)
145+
except websockets.exceptions.ConnectionClosed as e:
146+
raise websocket.WebSocketDisconnect() from e
147+
except RuntimeError as e:
148+
# starlette throws a RuntimeError once you call send after the connection is closed
149+
raise websocket.WebSocketDisconnect() from e
127150

128151
def close(self):
129152
if self.portal is None:
@@ -133,25 +156,25 @@ def close(self):
133156

134157
def send_text(self, data: str) -> None:
135158
if self.portal is None:
136-
task = self.event_loop.create_task(self.ws.send_text(data))
159+
task = self.event_loop.create_task(self._send_text_exc(data))
137160
self.tasks.add(task)
138161
task.add_done_callback(self.tasks.discard)
139162
else:
140163
if settings.main.experimental_performance:
141164
self.to_send.append(data)
142165
else:
143-
self.portal.call(self.ws.send_bytes, data) # type: ignore
166+
self.portal.call(self._send_bytes_exc, data) # type: ignore
144167

145168
def send_bytes(self, data: bytes) -> None:
146169
if self.portal is None:
147-
task = self.event_loop.create_task(self.ws.send_bytes(data))
170+
task = self.event_loop.create_task(self._send_bytes_exc(data))
148171
self.tasks.add(task)
149172
task.add_done_callback(self.tasks.discard)
150173
else:
151174
if settings.main.experimental_performance:
152175
self.to_send.append(data)
153176
else:
154-
self.portal.call(self.ws.send_bytes, data) # type: ignore
177+
self.portal.call(self._send_bytes_exc, data) # type: ignore
155178

156179
async def receive(self):
157180
if self.portal is None:

0 commit comments

Comments
 (0)