from datetime import datetime

- log_formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s',
-                                   '%m-%d-%Y %H:%M:%S')
+ log_formatter = logging.Formatter(
+     "%(asctime)s | %(levelname)s | %(message)s", "%m-%d-%Y %H:%M:%S"
+ )
- def retry(tries=5, delay=3, backoff=2):
-     """
-     Retries a function or method until it succeeds or the number of retries is exceeded.
-
-     :param tries: the maximum number of times to retry (default 4).
-     :param delay: the initial delay between retries in seconds (default 3).
-     :param backoff: the backoff multiplier (e.g. value of 2 will double the delay each retry) (default 2).
+ def retry(tries: int = 5, delay: int = 3, backoff: int = 2):
+     """Retry wrapper that retries a function on exceptions and errors.
+
+     Args:
+         tries (int, optional): Number of retries. Defaults to 5.
+         delay (int, optional): Time delay in seconds between retries. Defaults to 3.
+         backoff (int, optional): Multiplier applied to the delay after each retry. Defaults to 2.
    """

    def deco_retry(func):
@@ -30,7 +30,7 @@ def f_retry(*args, **kwargs):
            mtries, mdelay = tries, delay
            while mtries >= 0:
                try:
-                    #args[0].logger.info("Trying the the function.......")
+                    # args[0].logger.info("Trying the function.......")
                    return func(*args, **kwargs)
                except (
                    openai.error.APIError,
@@ -51,7 +51,19 @@ def f_retry(*args, **kwargs):
    return deco_retry


- def dot_product(list1, list2):
+ def dot_product(list1: list, list2: list):
+     """Computes the dot product of 2 vectors.
+
+     Args:
+         list1 (list): List 1 values
+         list2 (list): List 2 values
+
+     Raises:
+         ValueError: If the two lists do not have the same length.
+
+     Returns:
+         int: Dot product value of the 2 lists.
+     """
    # Check if the lengths of the two lists are equal
    if len(list1) != len(list2):
        raise ValueError("Lists must have same length")
@@ -76,7 +88,14 @@ class ChatGptSmartClient(object):

    CONTEXT_TOKEN_LIMIT = 4000

-     def __init__(self, api_key: str, model: str, log_info: bool=False):
+     def __init__(self, api_key: str, model: str, log_info: bool = False):
+         """Init method to instantiate the client.
+
+         Args:
+             api_key (str): OpenAI api key.
+             model (str): The model to be used for chat completion, e.g. "gpt-3.5-turbo".
+             log_info (bool, optional): Whether to log stats for each query. Defaults to False.
+         """
        openai.api_key = api_key

        self.instruction_msgs = {
@@ -94,7 +113,7 @@ def __init__(self, api_key: str, model: str, log_info: bool=False):

        self.rspid_vs_tottokens_dict = {}

-         self.total_token_cnt  = 0
+         self.total_token_cnt = 0

        self.log_info = log_info
        self.logger = logging.getLogger(f"chatgptlogger{time.time()}")
@@ -103,9 +122,19 @@ def __init__(self, api_key: str, model: str, log_info: bool=False):
        stdout_handler.setFormatter(log_formatter)
        self.logger.addHandler(stdout_handler)

-
    @retry()
-     def query(self, query: str, w_context=True, add_to_context=True):
+     def query(self, query: str, w_context: bool = True, add_to_context: bool = True):
+         """Wrapper method to make conversation with the ChatGPT APIs.
+
+         Args:
+             query (str): Query to send to the ChatGPT servers.
+             w_context (bool, optional): Whether to ask this query with the previous conversation context. Defaults to True.
+             add_to_context (bool, optional): Whether to add this query and its response to the conversation context. Defaults to True.
+
+         Returns:
+             OpenAIJson, int: Response to the query and the response id, for later rollback if necessary.
+         """
+
        # TODO: We could get the embeddings, cache and further use them to speed up the results.
        # self.get_embeddings(query=query)
@@ -120,14 +149,16 @@ def query(self, query: str, w_context=True, add_to_context=True):
            msgs.append(query)
        else:
            msgs = [self.instruction_msgs, query]
-
+
        start_time = time.time()
        response = openai.ChatCompletion.create(model=self.model, messages=msgs)
        end_time = time.time()

        if self.log_info:
-             self.logger.info(f"The time taken for this response is : {end_time - start_time} seconds")
-             #print(f"The time taken for this response is : {end_time - start_time} seconds")
+             self.logger.info(
+                 f"The time taken for this response is: {end_time - start_time} seconds"
+             )
+             # print(f"The time taken for this response is: {end_time - start_time} seconds")

        f_resp = response["choices"][0]["message"]
        tot_token_cnt = response["usage"]["total_tokens"]
@@ -138,9 +169,10 @@ def query(self, query: str, w_context=True, add_to_context=True):

        self.total_token_cnt += tot_token_cnt

-
        if self.log_info:
-             self.logger.info(f"The total token count currently is {self.total_token_cnt}")
+             self.logger.info(
+                 f"The total token count currently is {self.total_token_cnt}"
+             )

        if add_to_context:
            self.prev_msgs.append(query)
@@ -149,54 +181,62 @@ def query(self, query: str, w_context=True, add_to_context=True):
            rsp_id = self.rsp_id
            self.rspid_vs_tottokens_dict[self.rsp_id] = tot_token_cnt

-         #print(self.prev_msgs)
+         # print(self.prev_msgs)

        return f_resp, rsp_id
-     def erase_history(self):
+     def erase_history(self) -> None:
+         """Removes all previous context."""
        self.prev_msgs = [self.instruction_msgs]
        self.rsp_id = len(self.prev_msgs)
-
    # This function is used for getting embeddings and hence may be
    # used to speed up the system by caching.
-     def get_embeddings(self, query: str):
-         """_summary_
+     def get_embeddings(self, query: str) -> list:
+         """Calls the OpenAI embeddings API to get the generated embeddings for a query;
+         can be used later for caching and faster response times.

        Args:
-             query (str): _description_
+             query (str): Query to be sent to the OpenAI embeddings API.
        """
        response = openai.Embedding.create(input=query, model="text-embedding-ada-002")
        embeddings = response["data"][0]["embedding"]
        return embeddings
-     def rollback_conversation(self, rsp_id):
-         """Rollback conversation to the point of a particular response.
+     def rollback_conversation(self, rsp_id: int) -> None:
+         """Rolls back the conversation to the point of a particular response.

        Args:
-             rsp_id (int): Id number of previous tracked response to roll back to.
+             rsp_id (int): Id number of a previously tracked response to roll back to.
        """
        self.prev_msgs = self.prev_msgs[0:rsp_id]
        self.rsp_id = len(self.prev_msgs)
-
-     def print_metrics(self):
-         self.logger.info(f"The total tokens used up-till now is: {self.total_token_cnt}")
-         self.logger.info(f"The average response time is: {sum(self.rsp_time_list)/len(self.rsp_time_list)} sec")
-
-         self.plot_rsp_times()
-
-     def plot_rsp_times(self):
-         formatted_timestamps = [datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S') for ts in self.rsp_tstamp_list]
+     def print_metrics(self) -> None:
+         """Method to log the token usage and response time information."""
+         self.logger.info(
+             f"The total tokens used up till now is: {self.total_token_cnt}"
+         )
+         self.logger.info(
+             f"The average response time is: {sum(self.rsp_time_list)/len(self.rsp_time_list)} sec"
+         )
+
+         self.plot_rsp_times()
+
+     def plot_rsp_times(self) -> None:
+         """Method to plot the response times versus the timestamps."""
+         formatted_timestamps = [
+             datetime.fromtimestamp(ts).strftime("%Y-%m-%d %H:%M:%S")
+             for ts in self.rsp_tstamp_list
+         ]

        # Plot the response times against the timestamps
        fig, ax = plt.subplots()
        ax.plot(formatted_timestamps, self.rsp_time_list)
        ax.set_xlabel("Timestamp")
        ax.set_ylabel("Response Time (s)")
        ax.set_title("ChatGPT API Response Time")
-         ax.tick_params(axis='x', rotation=45)
+         ax.tick_params(axis="x", rotation=45)
        ax.xaxis.labelpad = 85
        ax.yaxis.labelpad = 35
@@ -211,51 +251,65 @@ def plot_rsp_times(self):
        # Show the plot
        plt.show()

-     def dump_context_to_a_file(self, filename="context"):
-         self.dicts_to_jsonl(data_list=self.prev_msgs, filename=filename)
-
-
-     def dicts_to_jsonl(self, data_list: list, filename: str, compress: bool = True) -> None:
-         """
-         Method saves list of dicts into jsonl file.
-         :param data_list: (list) list of dicts to be stored,
-         :param filename: (str) path to the output file. If suffix .jsonl is not given then methods appends
-             .jsonl suffix into the file.
-         :param compress: (bool) should file be compressed into a gzip archive?
-         """
-         sjsonl = '.jsonl'
+     def dump_context_to_a_file(self, filename: str = "context"):
+         """Wrapper method to call to dump the context to a file.
+
+         Args:
+             filename (str, optional): Filename to dump to. Defaults to "context".
+         """
+         self.dicts_to_jsonl(data_list=self.prev_msgs, filename=filename)
+
+     def dicts_to_jsonl(self, data_list: list, filename: str) -> None:
+         """Saves the context list of dicts in .jsonl format.
+
+         Args:
+             data_list (list): List of dicts of previous conversations.
+             filename (str): Filename to save to.
+         """
+         sjsonl = ".jsonl"

        # Check filename
        if not filename.endswith(sjsonl):
            filename = filename + sjsonl

        # Save data
-         with open(filename, 'w') as out:
+         with open(filename, "w") as out:
            for ddict in data_list:
-                 jout = json.dumps(ddict) + '\n'
+                 jout = json.dumps(ddict) + "\n"
                out.write(jout)
-
-     def load_context_from_a_file(self, filename):
-         sjsonl = '.jsonl'
+     def load_context_from_a_file(self, filename: str):
+         """Fills up self.prev_msgs, i.e. the context, from a dumped file.
+
+         Args:
+             filename (str): Filename should be a .jsonl file.
+
+         Returns:
+             list: List containing the context information.
+         """
+         sjsonl = ".jsonl"

        # Check filename
        if not filename.endswith(sjsonl):
            filename = filename + sjsonl
-
+
        self.prev_msgs = []
-
-         with open(filename, encoding='utf-8') as json_file:
+
+         with open(filename, encoding="utf-8") as json_file:
            for line in json_file.readlines():
                # Parse each JSONL line back into a dict, matching how the context was dumped
                self.prev_msgs.append(json.loads(line))

        return self.prev_msgs

    def trim_conversation(self):
+         """Method to trim the context generated till now to below the token limit.
+         The queries are removed in FIFO manner until the token count goes below the limit.
+         """
        while sum(self.total_token_cnt_list) >= self.CONTEXT_TOKEN_LIMIT:
-
-             del self.total_token_cnt_list[0]
+             del self.total_token_cnt_list[0]
            # Remove the oldest query/response pair; after the first del,
            # the old index 2 shifts to index 1, so index 1 is deleted twice.
            del self.prev_msgs[1]
            del self.prev_msgs[1]
-
-         self.logger.info(f"Trimmed the context list to length: {sum(self.total_token_cnt_list)}")
+
+         self.logger.info(
+             f"Trimmed the context list to length: {sum(self.total_token_cnt_list)}"
+         )
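For quick reference, a minimal usage sketch of the client as it stands after this commit. The class and method names are taken from the diff above; the environment-variable key handling and the example prompt are illustrative assumptions, not part of the commit:

import os

# Illustrative only: assumes OPENAI_API_KEY is set in the environment.
client = ChatGptSmartClient(
    api_key=os.environ["OPENAI_API_KEY"],
    model="gpt-3.5-turbo",
    log_info=True,
)

resp, rsp_id = client.query("Explain the dot product in one sentence.")
client.dump_context_to_a_file("context")  # writes context.jsonl
client.rollback_conversation(rsp_id)      # drop everything after this response
client.print_metrics()                    # log token usage and plot response times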