@@ -7,6 +7,7 @@
 import shelve
 from datetime import datetime
 from dataclasses import dataclass
+import tiktoken
 import openai
 from pymixin import log
 
@@ -24,7 +25,8 @@ class Message:
 g_conversations = shelve.open(f".db/conversations")
 
 default_role = 'You are a helpful assistant'
-max_prompt_token = 3000.0
+max_prompt_token = 3000
+gpt_encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
 
 class ChatGPTBot:
     def __init__(self, api_key: str, stream=True):
@@ -56,6 +58,10 @@ def get_parent_messsage(self, conversation_id: str, message_id: str) -> Optional
             return g_conversations[key]
         return None
 
+    def count_tokens(self, message) -> int:
+        tokens = gpt_encoding.encode(message)
+        return len(tokens)
+
     def add_messsage(self, conversation_id: str, query: str, reply: str) -> str:
 
         message_id = str(uuid.uuid4())
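
For reference, the new count_tokens helper swaps the old words-per-token guess for tiktoken's exact encoding. A minimal standalone sketch of the same idea (assuming tiktoken is installed; the sample strings are illustrative):

    import tiktoken

    # Load the tokenizer once, as the patch does at module level.
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")

    def count_tokens(message: str) -> int:
        # encode() maps text to a list of token ids; its length is the token count.
        return len(encoding.encode(message))

    print(count_tokens("hello world"))  # exact count instead of len(words) / 0.75
    print(count_tokens("你好，世界"))    # non-ASCII text needs no special casing
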
@@ -119,10 +125,7 @@ def generate_prompt(self, conversation_id: str, message: str) -> Optional[List[D
             context_messages.append({"role": "user", "content": message})
             return context_messages
 
-        tokens_count = len(message.split()) / 0.75
-        for ch in message:
-            if ord(ch) > 256:
-                tokens_count += 2
+        tokens_count = self.count_tokens(message)
 
         if tokens_count > max_prompt_token:
             return None
@@ -132,11 +135,8 @@ def generate_prompt(self, conversation_id: str, message: str) -> Optional[List[D
             parent_message = self.get_parent_messsage(conversation_id, parent_message_id)
             assert parent_message
             contents = ' '.join((parent_message.message, parent_message.completion))
-            current_tokens_count = len(contents.split()) / 0.75
+            current_tokens_count = self.count_tokens(contents)
             # count unicode characters tokens
-            for ch in contents:
-                if ord(ch) > 256:
-                    current_tokens_count += 2
             if tokens_count + current_tokens_count > max_prompt_token:
                 break
             tokens_count += current_tokens_count
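
This hunk shows only the inside of the history-walking loop: each parent message's query and completion are joined, counted, and added until the total would exceed max_prompt_token. A simplified sketch of that budgeting pattern, reusing count_tokens from the sketch above; the names history and budget are hypothetical stand-ins for parts the diff does not show:

    def fit_history(history: list[str], new_message: str, budget: int = 3000) -> list[str]:
        # Spend the token budget on the new message first, then on past turns.
        used = count_tokens(new_message)
        kept = []
        for past in history:  # newest-first, like walking parent links
            cost = count_tokens(past)
            if used + cost > budget:  # stop before the prompt overflows
                break
            used += cost
            kept.append(past)
        return kept
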
@@ -218,6 +218,10 @@ async def _send_message_stream(self, conversation_id: str, message: str):
         self.users[conversation_id] = True
 
         prompt = self.generate_prompt(conversation_id, message)
+        if not prompt:
+            yield '[BEGIN]'
+            yield 'oops, something went wrong, please try to reduce your words.'
+            return
         start_time = time.time()
         try:
             yield '[BEGIN]'
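
The last hunk adds a guard to the streaming path: when generate_prompt returns None (the message alone already exceeds the budget), the generator now yields the '[BEGIN]' marker and an error line, then stops instead of calling the API. A toy consumer showing how a caller sees that early return (stream_stub is a hypothetical stand-in, not the bot's real method):

    import asyncio

    async def stream_stub():
        # Mirrors the new guard: marker, error text, then stop.
        yield '[BEGIN]'
        yield 'oops, something went wrong, please try to reduce your words.'
        return

    async def main():
        async for chunk in stream_stub():
            print(chunk)

    asyncio.run(main())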