1
1
#!/usr/bin/python
2
2
import requests
3
3
from WebChatGPT import utils
4
- import logging
5
4
import json
6
5
import re
7
6
from functools import lru_cache
8
- import websocket
9
- from base64 import b64decode
10
- from WebChatGPT .errors import WebSocketError
11
- from threading import Thread as thr
12
7
from typing import Iterator
13
- from .errors import MaximumRetrialError
14
-
15
-
16
class Websocket:
    """Receive streamed ChatGPT response chunks over a websocket.

    Each incoming frame carries a base64-encoded ``body``; frames are
    decoded and mirrored onto the owning ``chatgpt`` object so a caller
    can poll ``chatgpt.loading_chunk`` for new data and
    ``chatgpt.socket_closed`` to detect end of stream.
    """

    def __init__(
        self,
        data: dict,
        chatgpt: object,
        trace: bool = False,
    ):
        """Prepare the connection payload and reset the owner's stream state.

        Args:
            data: Server registration response; must contain ``wss_url``
                plus the payload to send once the socket opens.
            chatgpt: Owning client; receives decoded chunks as attributes.
            trace: Enable websocket-client protocol tracing.
        """
        chatgpt.socket_closed = False
        chatgpt.loading_chunk = ""
        # Copy so we can strip the URL without mutating the caller's dict;
        # everything except `wss_url` is sent back to the server on open.
        self.payload = data.copy()
        self.url = self.payload.pop("wss_url")
        self.chatgpt = chatgpt
        # Local mirrors of the last chunk seen (the owner also gets copies
        # in on_message) — kept for backward compatibility.
        self.last_response_chunk: dict = {}
        self.last_response_undecoded_chunk: dict = {}
        websocket.enableTrace(trace)

    def on_message(self, ws, message):
        """Decode a frame's base64 ``body`` and publish it to the owner."""
        response = json.loads(message)
        self.chatgpt.last_response_undecoded_chunk = response
        decoded_body = b64decode(response["body"]).decode("utf-8")
        response["body"] = decoded_body
        self.chatgpt.last_response_chunk = response
        # loading_chunk changing is the owner's signal that new data arrived.
        self.chatgpt.loading_chunk = decoded_body

    def on_error(self, ws, error):
        """Mark the socket closed, then surface the error to the caller.

        Raises:
            WebSocketError: always, wrapping the underlying error.
        """
        # Fix: forward the actual websocket object, not the string "ws"
        # (the original only worked because on_close ignores its argument).
        self.on_close(ws)
        raise WebSocketError(error)

    def on_close(self, ws, *args, **kwargs):
        """Flag the owner so its polling loop can stop."""
        self.chatgpt.socket_closed = True

    def on_open(
        self,
        ws,
    ):
        """Send the registration payload as soon as the socket opens."""
        # indent=4 kept deliberately: preserves the exact bytes sent on the wire.
        json_data = json.dumps(self.payload, indent=4)
        ws.send(json_data)

    def run(
        self,
    ):
        """Connect and block until the server closes the stream."""
        ws = websocket.WebSocketApp(
            self.url,
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close,
            on_open=self.on_open,
        )
        ws.run_forever(origin="https://chat.openai.com")
67
8
68
9
69
10
class ChatGPT :
@@ -127,6 +68,9 @@ def __init__(
127
68
self .stop_sharing_conversation_endpoint = (
128
69
"https://chat.openai.com/backend-api/%(share_id)s"
129
70
)
71
+ self .sentinel_chat_requirements_endpoint : str = (
72
+ "https://chat.openai.com/backend-api/sentinel/chat-requirements"
73
+ )
130
74
self .session .headers ["User-Agent" ] = user_agent
131
75
self .locale = locale
132
76
self .model = model
@@ -139,12 +83,7 @@ def __init__(
139
83
self .__already_init = False
140
84
self .__index = conversation_index
141
85
self .__title_cache = {}
142
- self .last_response_undecoded_chunk : str = ""
143
- self .last_response_chunk : dict = {}
144
- self .loading_chunk : str = ""
145
- self .socket_closed : bool = True
146
- self .trace = trace
147
- self .request_more_times : int = 2
86
+ self .stream_chunk_size = 64
148
87
# self.register_ws =self.session.post("https://chat.openai.com/backend-api/register-websocket")
149
88
# Websocket(self.register_ws.json(),self).run()
150
89
@@ -171,6 +110,13 @@ def current_conversation_id(self):
171
110
def get_current_message_id (self ):
172
111
return self .last_response_metadata .get (2 ).get ("message_id" )
173
112
113
def update_sentinel_tokens(self):
    """Fetch a fresh sentinel token and attach it to the session headers.

    Posts an empty JSON body to the chat-requirements endpoint and stores
    the returned token as the ``OpenAI-Sentinel-Chat-Requirements-Token``
    header, which subsequent conversation requests require.

    Raises:
        requests.HTTPError: if the endpoint responds with an error status.
    """
    response = self.session.post(self.sentinel_chat_requirements_endpoint, json={})
    response.raise_for_status()
    token = response.json()["token"]
    self.session.headers["OpenAI-Sentinel-Chat-Requirements-Token"] = token
174
120
def ask (
175
121
self ,
176
122
prompt : str ,
@@ -228,32 +174,28 @@ def ask(
228
174
}
229
175
```
230
176
"""
177
+ self .update_sentinel_tokens ()
231
178
response = self .session .post (
232
179
url = self .conversation_endpoint ,
233
180
json = self .__generate_payload (prompt ),
234
181
timeout = self .timeout ,
235
- stream = False ,
182
+ stream = True ,
236
183
)
237
- response .raise_for_status ()
238
- ws_payload = dict (response .json ())
239
- self .__request_more_count : int = 0
240
-
241
- # out = lambda v:print(json.dumps(dict(v), indent=4))
242
- # out(response.headers)
243
- def for_stream ():
244
-
245
- ws = Websocket (ws_payload , self , self .trace )
246
- t1 = thr (target = ws .run )
247
- t1 .start ()
248
- cached_loading_chunk = self .loading_chunk
249
- cached_last_response = self .last_response .copy ()
250
- while True :
251
- if self .loading_chunk != cached_loading_chunk :
252
- # New chunk loaded
184
+ # response.raise_for_status()
185
+ if (
186
+ response .ok
187
+ and response .headers .get ("content-type" )
188
+ == "text/event-stream; charset=utf-8"
189
+ ):
190
+
191
+ def for_stream ():
192
+ for value in response .iter_lines (
193
+ decode_unicode = True ,
194
+ delimiter = "data:" ,
195
+ chunk_size = self .stream_chunk_size ,
196
+ ):
253
197
try :
254
- value = self .loading_chunk
255
- # print(value)
256
- to_dict = json .loads (value [5 :])
198
+ to_dict = json .loads (value )
257
199
if "is_completion" in to_dict .keys ():
258
200
# Metadata (response)
259
201
self .last_response_metadata [
@@ -269,40 +211,35 @@ def for_stream():
269
211
yield value
270
212
pass
271
213
272
- finally :
273
- cached_loading_chunk = self .loading_chunk
274
-
275
- if self .socket_closed :
276
- t1 .join ()
277
- break
278
-
279
- if (
280
- self .last_response == cached_last_response
281
- or self .last_response ["message" ]["status" ] != "finished_successfully"
282
- ):
283
-
284
- # print(json.dumps(self.last_response, indent=4))
285
- # print("Requesting more body")
286
- # print('=='*40)
287
- t1 .join ()
288
- if self .__request_more_count >= self .request_more_times :
289
- raise MaximumRetrialError (
290
- f"Failed to generate response after { self .request_more_times } attempts"
291
- )
292
-
293
- for value in for_stream ():
294
- yield value
295
-
296
- self .__request_more_count += 1
297
- # else:
298
- # print(print(json.dumps(self.last_response_chunk, indent=4)))
214
+ def for_non_stream ():
215
+ response_to_be_returned = {}
216
+ for value in response .iter_lines (
217
+ decode_unicode = True ,
218
+ delimiter = "data:" ,
219
+ chunk_size = self .stream_chunk_size ,
220
+ ):
221
+ try :
222
+ to_dict = json .loads (value )
223
+ if "is_completion" in to_dict .keys ():
224
+ # Metadata (response)
225
+ self .last_response_metadata [
226
+ 2 if to_dict .get ("is_completion" ) else 1
227
+ ] = to_dict
228
+ continue
229
+ # Only data containing the `feedback body` make it to here
230
+ self .last_response .update (to_dict )
231
+ response_to_be_returned .update (to_dict )
232
+ except json .decoder .JSONDecodeError :
233
+ # Caused by either empty string or [DONE]
234
+ pass
235
+ return response_to_be_returned
299
236
300
- def for_non_stream ():
301
- for _ in for_stream ():
302
- pass
303
- return self .last_response
237
+ return for_stream () if stream else for_non_stream ()
304
238
305
- return for_stream () if stream else for_non_stream ()
239
+ else :
240
+ raise Exception (
241
+ f"Failed to fetch response - ({ response .status_code } , { response .reason } : { response .headers .get ('content-type' )} : { response .text } "
242
+ )
306
243
307
244
def chat (self , prompt : str , stream : bool = False ) -> str :
308
245
"""Interact with ChatGPT on the fly
0 commit comments