Commit 7705ee1

Improved token counting
1 parent: 91c96c7

File tree

2 files changed: 14 additions, 9 deletions

setup.cfg

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ long_description_content_type = text/markdown
 install_requires =
     pyyaml
     mixin-python
+    tiktoken
     openai>=0.27.0
 
 [options.extras_require]
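
For context, tiktoken is OpenAI's BPE tokenizer library; adding it here backs the exact token counting introduced in src/chatgpt_openai.py below. A minimal sketch of the API the commit relies on (the sample string is illustrative):

import tiktoken

# Load the tokenizer that matches the target chat model.
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")

# encode() returns the list of integer token ids;
# its length is the token count for that model.
token_ids = encoding.encode("Improved token counting")
print(len(token_ids))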

src/chatgpt_openai.py

Lines changed: 13 additions & 9 deletions
@@ -7,6 +7,7 @@
 import shelve
 from datetime import datetime
 from dataclasses import dataclass
+import tiktoken
 import openai
 from pymixin import log
 
@@ -24,7 +25,8 @@ class Message:
 g_conversations = shelve.open(f".db/conversations")
 
 default_role = 'You are a helpful assistant'
-max_prompt_token = 3000.0
+max_prompt_token = 3000
+gpt_encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
 
 class ChatGPTBot:
     def __init__(self, api_key: str, stream=True):
@@ -56,6 +58,10 @@ def get_parent_messsage(self, conversation_id: str, message_id: str) -> Optional
             return g_conversations[key]
         return None
 
+    def count_tokens(self, message) -> int:
+        tokens = gpt_encoding.encode(message)
+        return len(tokens)
+
     def add_messsage(self, conversation_id: str, query: str, reply: str) -> str:
 
         message_id = str(uuid.uuid4())
@@ -119,10 +125,7 @@ def generate_prompt(self, conversation_id: str, message: str) -> Optional[List[D
             context_messages.append({"role": "user", "content": message})
             return context_messages
 
-        tokens_count = len(message.split()) / 0.75
-        for ch in message:
-            if ord(ch) > 256:
-                tokens_count += 2
+        tokens_count = self.count_tokens(message)
 
         if tokens_count > max_prompt_token:
             return None
@@ -132,11 +135,8 @@ def generate_prompt(self, conversation_id: str, message: str) -> Optional[List[D
             parent_message = self.get_parent_messsage(conversation_id, parent_message_id)
             assert parent_message
             contents = ' '.join((parent_message.message, parent_message.completion))
-            current_tokens_count = len(contents.split()) / 0.75
+            current_tokens_count = self.count_tokens(contents)
             # count unicode characters tokens
-            for ch in contents:
-                if ord(ch) > 256:
-                    current_tokens_count += 2
             if tokens_count + current_tokens_count > max_prompt_token:
                 break
             tokens_count += current_tokens_count
@@ -218,6 +218,10 @@ async def _send_message_stream(self, conversation_id: str, message: str):
         self.users[conversation_id] = True
 
         prompt = self.generate_prompt(conversation_id, message)
+        if not prompt:
+            yield '[BEGIN]'
+            yield 'oops, something went wrong, please try to reduce your words.'
+            return
         start_time = time.time()
         try:
             yield '[BEGIN]'
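
The dropped heuristic assumed roughly 0.75 words per token and added two tokens for each character outside the Latin-1 range; tiktoken counts the model's actual BPE tokens, so the comparison against max_prompt_token no longer relies on an estimate. A standalone sketch contrasting the two approaches (the sample text is illustrative):

import tiktoken

gpt_encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")

def count_tokens_heuristic(message: str) -> float:
    # Old approach: word count scaled by ~0.75 words per token,
    # plus 2 extra tokens per non-Latin-1 character.
    count = len(message.split()) / 0.75
    for ch in message:
        if ord(ch) > 256:
            count += 2
    return count

def count_tokens_exact(message: str) -> int:
    # New approach: length of the real token id sequence.
    return len(gpt_encoding.encode(message))

sample = 'You are a helpful assistant'
print(count_tokens_heuristic(sample))  # rough estimate
print(count_tokens_exact(sample))      # exact count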
