Skip to content

Commit 696c50c

Browse files
committed
Added context trimming.
1 parent ac5f925 commit 696c50c

File tree

6 files changed

+72
-12
lines changed

6 files changed

+72
-12
lines changed

README.md

+12-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,15 @@ response = chatgpt_client.query(prompt, w_context=True, add_to_context=False)
2727
```
2828

2929
## Features
30-
We currently have the features of
30+
31+
32+
- [x] Save Conversations to a file
33+
- [x] Resume conversations by loading context from a file.
34+
- [x] Retry logic in case of API failures.
35+
- [x] Regular trimming of context to 4000 tokens so that the limit of 4097 is not breached.
36+
- [x] Total token count and token-vs-time metrics.
37+
38+
3139
- Retries: This is in case of failures like connection-based request exceptions and API errors.
3240
```
3341
(openai) C:\Users\Srinivas\OneDrive\Desktop\StartupSearchGPT\tests>python test_main.py
@@ -39,6 +47,9 @@ We currently have the features of
3947
Retrying after 12 seconds...
4048
Error occurred: API error , please try later
4149
```
50+
- Context trimming: Context is trimmed as needed when it exceeds 4000 tokens.
51+
![Trimming and printing metrics](printed_metrics.png)
52+
4253
- Tracking metrics such as average time per response and total token usage.
4354
```
4455
04-10-2023 10:26:44 | INFO | The time taken for this response is : 7.85 seconds

printed_metrics.png

57.6 KB
Loading

response_times.png

16.6 KB
Loading

src/main.py

+30-5
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ class ChatGptSmartClient(object):
7373
read more on that here: https://platform.openai.com/docs/guides/chat .
7474
"""
7575

76+
CONTEXT_TOKEN_LIMIT = 4000
77+
7678
def __init__(self, api_key: str, model: str, log_info: bool=False):
7779
openai.api_key = api_key
7880

@@ -89,6 +91,10 @@ def __init__(self, api_key: str, model: str, log_info: bool=False):
8991
self.rsp_tstamp_list = []
9092
self.total_token_cnt_list = []
9193

94+
self.rspid_vs_tottokens_dict = {}
95+
96+
self.total_token_cnt = 0
97+
9298
self.log_info = log_info
9399
self.logger = logging.getLogger("chatgptlogger")
94100
self.logger.setLevel(logging.INFO)
@@ -105,6 +111,9 @@ def query(self, query: str, w_context=True, add_to_context=True):
105111
query = {"role": "user", "content": query}
106112
rsp_id = None
107113

114+
if sum(self.total_token_cnt_list) >= self.CONTEXT_TOKEN_LIMIT:
115+
self.trim_conversation()
116+
108117
if w_context:
109118
msgs = self.prev_msgs[:]
110119
msgs.append(query)
@@ -126,13 +135,20 @@ def query(self, query: str, w_context=True, add_to_context=True):
126135
self.rsp_time_list.append(end_time - start_time)
127136
self.total_token_cnt_list.append(tot_token_cnt)
128137

138+
self.total_token_cnt += tot_token_cnt
139+
140+
129141
if self.log_info:
130-
self.logger.info(f"The total token count currently is {sum(self.total_token_cnt_list)}")
142+
self.logger.info(f"The total token count currently is {self.total_token_cnt}")
131143

132144
if add_to_context:
145+
self.prev_msgs.append(query)
133146
self.prev_msgs.append(f_resp)
134147
self.rsp_id += 1
135148
rsp_id = self.rsp_id
149+
self.rspid_vs_tottokens_dict[self.rsp_id] = tot_token_cnt
150+
151+
#print(self.prev_msgs)
136152

137153
return f_resp, rsp_id
138154

@@ -163,7 +179,7 @@ def rollback_conversation(self, rsp_id):
163179
self.rsp_id = len(self.prev_msgs)
164180

165181
def print_metrics(self):
166-
self.logger.info(f"The total tokens used up-till now is: {sum(self.total_token_cnt_list)}")
182+
self.logger.info(f"The total tokens used up-till now is: {self.total_token_cnt}")
167183
self.logger.info(f"The average response time is: {sum(self.rsp_time_list)/len(self.rsp_time_list)} sec")
168184

169185
self.plot_rsp_times()
@@ -218,7 +234,7 @@ def dicts_to_jsonl(self, data_list: list, filename: str, compress: bool = True)
218234
jout = json.dumps(ddict) + '\n'
219235
out.write(jout)
220236

221-
def load_context_from_a_file(self, file):
237+
def load_context_from_a_file(self, filename):
222238

223239
sjsonl = '.jsonl'
224240

@@ -228,8 +244,17 @@ def load_context_from_a_file(self, file):
228244

229245
self.prev_msgs = []
230246

231-
with open(file, encoding='utf-8') as json_file:
247+
with open(filename, encoding='utf-8') as json_file:
232248
for line in json_file.readlines():
233249
self.prev_msgs.append(line)
234250

235-
return self.prev_msgs
251+
return self.prev_msgs
252+
253+
def trim_conversation(self):
    """Drop the oldest exchanges until the tracked token total is under the limit.

    Repeatedly removes the oldest entry from ``total_token_cnt_list`` and the
    corresponding oldest user/assistant message pair from ``prev_msgs`` until
    the running token sum falls below ``CONTEXT_TOKEN_LIMIT``.

    NOTE(review): this assumes ``prev_msgs[0]`` is a system/instruction message
    that must be preserved (the code never deletes index 0), and that the rest
    of ``prev_msgs`` holds alternating user/assistant pairs — confirm against
    how ``query`` appends to context.
    """
    while sum(self.total_token_cnt_list) >= self.CONTEXT_TOKEN_LIMIT:
        # Forget the oldest response's token count ...
        del self.total_token_cnt_list[0]
        # ... and its user + assistant messages. After the first del the
        # assistant reply shifts down to index 1, so index 1 is deleted twice.
        # (The original `del [1]` then `del [2]` skipped a message, leaving an
        # orphaned entry and breaking the user/assistant pairing.)
        del self.prev_msgs[1]
        del self.prev_msgs[1]

    # The value printed is a token sum, not a list length.
    print(f"Trimmed the context; remaining tracked token count: {sum(self.total_token_cnt_list)}")

tests/main.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ class ChatGptSmartClient(object):
7373
read more on that here: https://platform.openai.com/docs/guides/chat .
7474
"""
7575

76+
CONTEXT_TOKEN_LIMIT = 4000
77+
7678
def __init__(self, api_key: str, model: str, log_info: bool=False):
7779
openai.api_key = api_key
7880

@@ -89,6 +91,10 @@ def __init__(self, api_key: str, model: str, log_info: bool=False):
8991
self.rsp_tstamp_list = []
9092
self.total_token_cnt_list = []
9193

94+
self.rspid_vs_tottokens_dict = {}
95+
96+
self.total_token_cnt = 0
97+
9298
self.log_info = log_info
9399
self.logger = logging.getLogger("chatgptlogger")
94100
self.logger.setLevel(logging.INFO)
@@ -105,6 +111,9 @@ def query(self, query: str, w_context=True, add_to_context=True):
105111
query = {"role": "user", "content": query}
106112
rsp_id = None
107113

114+
if sum(self.total_token_cnt_list) >= self.CONTEXT_TOKEN_LIMIT:
115+
self.trim_conversation()
116+
108117
if w_context:
109118
msgs = self.prev_msgs[:]
110119
msgs.append(query)
@@ -126,13 +135,20 @@ def query(self, query: str, w_context=True, add_to_context=True):
126135
self.rsp_time_list.append(end_time - start_time)
127136
self.total_token_cnt_list.append(tot_token_cnt)
128137

138+
self.total_token_cnt += tot_token_cnt
139+
140+
129141
if self.log_info:
130-
self.logger.info(f"The total token count currently is {sum(self.total_token_cnt_list)}")
142+
self.logger.info(f"The total token count currently is {self.total_token_cnt}")
131143

132144
if add_to_context:
145+
self.prev_msgs.append(query)
133146
self.prev_msgs.append(f_resp)
134147
self.rsp_id += 1
135148
rsp_id = self.rsp_id
149+
self.rspid_vs_tottokens_dict[self.rsp_id] = tot_token_cnt
150+
151+
#print(self.prev_msgs)
136152

137153
return f_resp, rsp_id
138154

@@ -163,7 +179,7 @@ def rollback_conversation(self, rsp_id):
163179
self.rsp_id = len(self.prev_msgs)
164180

165181
def print_metrics(self):
166-
self.logger.info(f"The total tokens used up-till now is: {sum(self.total_token_cnt_list)}")
182+
self.logger.info(f"The total tokens used up-till now is: {self.total_token_cnt}")
167183
self.logger.info(f"The average response time is: {sum(self.rsp_time_list)/len(self.rsp_time_list)} sec")
168184

169185
self.plot_rsp_times()
@@ -188,7 +204,6 @@ def plot_rsp_times(self):
188204

189205
plt.xticks(rotation=45, fontsize=6)
190206

191-
192207
# Save the figure as a PNG file
193208
plt.savefig("response_times.png")
194209

@@ -233,4 +248,13 @@ def load_context_from_a_file(self, filename):
233248
for line in json_file.readlines():
234249
self.prev_msgs.append(line)
235250

236-
return self.prev_msgs
251+
return self.prev_msgs
252+
253+
def trim_conversation(self):
    """Drop the oldest exchanges until the tracked token total is under the limit.

    Repeatedly removes the oldest entry from ``total_token_cnt_list`` and the
    corresponding oldest user/assistant message pair from ``prev_msgs`` until
    the running token sum falls below ``CONTEXT_TOKEN_LIMIT``.

    NOTE(review): this assumes ``prev_msgs[0]`` is a system/instruction message
    that must be preserved (the code never deletes index 0), and that the rest
    of ``prev_msgs`` holds alternating user/assistant pairs — confirm against
    how ``query`` appends to context.
    """
    while sum(self.total_token_cnt_list) >= self.CONTEXT_TOKEN_LIMIT:
        # Forget the oldest response's token count ...
        del self.total_token_cnt_list[0]
        # ... and its user + assistant messages. After the first del the
        # assistant reply shifts down to index 1, so index 1 is deleted twice.
        # (The original `del [1]` then `del [2]` skipped a message, leaving an
        # orphaned entry and breaking the user/assistant pairing.)
        del self.prev_msgs[1]
        del self.prev_msgs[1]

    # The value printed is a token sum, not a list length.
    print(f"Trimmed the context; remaining tracked token count: {sum(self.total_token_cnt_list)}")

tests/test_main.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
api_key = "your_api_key"
1+
api_key = "api_key"
22
model_name = "gpt-3.5-turbo"
33

44
from main import ChatGptSmartClient
55

66

77
chatgptsmtclient = ChatGptSmartClient(api_key=api_key, model=model_name, log_info=True)
88

9-
for _ in range(2):
9+
for _ in range(10):
1010
chatgptsmtclient.query("List the top 10 upcoming startups in India?")
1111
chatgptsmtclient.query("Ok thanks, can you giv me the valuation of these startups in tabuar format")
1212

0 commit comments

Comments
 (0)