15
15
import starlette .websockets
16
16
import uvicorn .server
17
17
import websockets .legacy .http
18
+ import websockets .exceptions
18
19
19
20
from solara .server .utils import path_is_child_of
20
21
@@ -121,9 +122,31 @@ async def process_messages_task(self):
121
122
while len (self .to_send ) > 0 :
122
123
first = self .to_send .pop (0 )
123
124
if isinstance (first , bytes ):
124
- await self .ws . send_bytes (first )
125
+ await self ._send_bytes_exc (first )
125
126
else :
126
- await self .ws .send_text (first )
127
+ await self ._send_text_exc (first )
128
+
129
+ async def _send_bytes_exc (self , data : bytes ):
130
+ # make sures we catch the starlette/websockets specific exception
131
+ # and re-raise it as a websocket.WebSocketDisconnect
132
+ try :
133
+ await self .ws .send_bytes (data )
134
+ except websockets .exceptions .ConnectionClosed as e :
135
+ raise websocket .WebSocketDisconnect () from e
136
+ except RuntimeError as e :
137
+ # starlette throws a RuntimeError once you call send after the connection is closed
138
+ raise websocket .WebSocketDisconnect () from e
139
+
140
+ async def _send_text_exc (self , data : str ):
141
+ # make sures we catch the starlette/websockets specific exception
142
+ # and re-raise it as a websocket.WebSocketDisconnect
143
+ try :
144
+ await self .ws .send_text (data )
145
+ except websockets .exceptions .ConnectionClosed as e :
146
+ raise websocket .WebSocketDisconnect () from e
147
+ except RuntimeError as e :
148
+ # starlette throws a RuntimeError once you call send after the connection is closed
149
+ raise websocket .WebSocketDisconnect () from e
127
150
128
151
def close (self ):
129
152
if self .portal is None :
@@ -133,25 +156,25 @@ def close(self):
133
156
134
157
def send_text (self , data : str ) -> None :
135
158
if self .portal is None :
136
- task = self .event_loop .create_task (self .ws . send_text (data ))
159
+ task = self .event_loop .create_task (self ._send_text_exc (data ))
137
160
self .tasks .add (task )
138
161
task .add_done_callback (self .tasks .discard )
139
162
else :
140
163
if settings .main .experimental_performance :
141
164
self .to_send .append (data )
142
165
else :
143
- self .portal .call (self .ws . send_bytes , data ) # type: ignore
166
+ self .portal .call (self ._send_bytes_exc , data ) # type: ignore
144
167
145
168
def send_bytes (self , data : bytes ) -> None :
146
169
if self .portal is None :
147
- task = self .event_loop .create_task (self .ws . send_bytes (data ))
170
+ task = self .event_loop .create_task (self ._send_bytes_exc (data ))
148
171
self .tasks .add (task )
149
172
task .add_done_callback (self .tasks .discard )
150
173
else :
151
174
if settings .main .experimental_performance :
152
175
self .to_send .append (data )
153
176
else :
154
- self .portal .call (self .ws . send_bytes , data ) # type: ignore
177
+ self .portal .call (self ._send_bytes_exc , data ) # type: ignore
155
178
156
179
async def receive (self ):
157
180
if self .portal is None :
0 commit comments