
Commit f49fdbb

Added documentation strings, formatted code using black.

1 parent e27e13b commit f49fdbb

3 files changed: +119 -63 lines changed

pyproject.toml (+1 -1)

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "chatgptdevfriendly"
-version = "0.0.2"
+version = "0.0.1"
 authors = [
   { name="Srinivas Kumar R", email="[email protected]" },
 ]

src/chatgptdevfriendly/v1.py (+114 -60)
@@ -10,18 +10,18 @@
 from datetime import datetime
 
 
+log_formatter = logging.Formatter(
+    "%(asctime)s | %(levelname)s | %(message)s", "%m-%d-%Y %H:%M:%S"
+)
 
-log_formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s',
-                                  '%m-%d-%Y %H:%M:%S')
 
+def retry(tries: int = 5, delay: int = 3, backoff: int = 2):
+    """Retry wrapper, to retry functions on exceptions and errors.
 
-def retry(tries=5, delay=3, backoff=2):
-    """
-    Retries a function or method until it succeeds or the number of retries is exceeded.
-
-    :param tries: the maximum number of times to retry (default 4).
-    :param delay: the initial delay between retries in seconds (default 3).
-    :param backoff: the backoff multiplier (e.g. value of 2 will double the delay each retry) (default 2).
+    Args:
+        tries (int, optional): Number of retries. Defaults to 5.
+        delay (int, optional): Time delay in seconds. Defaults to 3.
+        backoff (int, optional): Backoff multiplier between retries. Defaults to 2.
     """
 
     def deco_retry(func):
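
With the defaults above (tries=5, delay=3, backoff=2), the wrapped call's wait times grow geometrically: roughly 3, 6, 12, 24, 48 seconds. A minimal sketch of that schedule, illustrative only and not part of the library:

    # Hypothetical helper: reproduces the sleep schedule the retry() decorator
    # walks through as it multiplies the delay by the backoff factor.
    def backoff_schedule(tries: int = 5, delay: int = 3, backoff: int = 2) -> list:
        waits, mdelay = [], delay
        for _ in range(tries):
            waits.append(mdelay)
            mdelay *= backoff
        return waits

    print(backoff_schedule())  # [3, 6, 12, 24, 48]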
@@ -30,7 +30,7 @@ def f_retry(*args, **kwargs):
             mtries, mdelay = tries, delay
             while mtries >= 0:
                 try:
-                    #args[0].logger.info("Trying the the function.......")
+                    # args[0].logger.info("Trying the function.......")
                     return func(*args, **kwargs)
                 except (
                     openai.error.APIError,
@@ -51,7 +51,19 @@ def f_retry(*args, **kwargs):
     return deco_retry
 
 
-def dot_product(list1, list2):
+def dot_product(list1: list, list2: list):
+    """Computes the dot product of 2 vectors.
+
+    Args:
+        list1 (list): List 1 values.
+        list2 (list): List 2 values.
+
+    Raises:
+        ValueError: If the two lists are not of the same length.
+
+    Returns:
+        int: Dot product value of the 2 lists.
+    """
     # Check if the lengths of the two lists are equal
     if len(list1) != len(list2):
         raise ValueError("Lists must have same length")
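
A quick worked check of dot_product as defined above; the import path is an assumption based on this repo's layout:

    from chatgptdevfriendly.v1 import dot_product  # assumed import path

    print(dot_product([1, 2, 3], [4, 5, 6]))  # (1*4) + (2*5) + (3*6) = 32

    try:
        dot_product([1, 2], [1, 2, 3])
    except ValueError as err:
        print(err)  # "Lists must have same length"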
@@ -76,7 +88,14 @@ class ChatGptSmartClient(object):
 
     CONTEXT_TOKEN_LIMIT = 4000
 
-    def __init__(self, api_key: str, model: str, log_info: bool=False):
+    def __init__(self, api_key: str, model: str, log_info: bool = False):
+        """Init method to instantiate the client.
+
+        Args:
+            api_key (str): OpenAI api key.
+            model (str): The model to be used for chat completion, ex: "gpt-3.5-turbo"
+            log_info (bool, optional): Whether to log stats for each query. Defaults to False.
+        """
         openai.api_key = api_key
 
         self.instruction_msgs = {
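
Given the signature documented above, instantiating the client might look like the following sketch (the key is a placeholder, not a real credential):

    from chatgptdevfriendly.v1 import ChatGptSmartClient  # assumed import path

    client = ChatGptSmartClient(
        api_key="sk-...",       # placeholder OpenAI API key
        model="gpt-3.5-turbo",  # example model named in the docstring
        log_info=True,          # log timing and token stats per query
    )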
@@ -94,7 +113,7 @@ def __init__(self, api_key: str, model: str, log_info: bool=False):
 
         self.rspid_vs_tottokens_dict = {}
 
-        self.total_token_cnt  = 0
+        self.total_token_cnt = 0
 
         self.log_info = log_info
         self.logger = logging.getLogger(f"chatgptlogger{time.time()}")
@@ -103,9 +122,19 @@ def __init__(self, api_key: str, model: str, log_info: bool=False):
         stdout_handler.setFormatter(log_formatter)
         self.logger.addHandler(stdout_handler)
 
-
     @retry()
-    def query(self, query: str, w_context=True, add_to_context=True):
+    def query(self, query: str, w_context: bool = True, add_to_context: bool = True):
+        """Wrapper method to make conversation with the ChatGPT APIs.
+
+        Args:
+            query (str): Query to send to the ChatGPT servers.
+            w_context (bool, optional): Whether to ask this query with the previous conversation context. Defaults to True.
+            add_to_context (bool, optional): Whether to add this query and response to the conversation context. Defaults to True.
+
+        Returns:
+            OpenAIJson, int: Response to the query and the response id, for later rollback if necessary.
+        """
+
         # TODO: We could get the embeddings, cache and further use them to speed up the results.
         # self.get_embeddings(query=query)

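A hypothetical call pattern for query(), following the docstring above. It returns the response message plus the id used for rollbacks; this sketch reuses the assumed `client` from the __init__ example and assumes dict-style access to the message, as in openai<1.0:

    response, rsp_id = client.query("List three uses of vector databases.")
    print(response["content"])  # the assistant's reply text

    # Follow-up that uses the saved context but is not added back to it.
    aside, _ = client.query("Summarize that in one line.", add_to_context=False)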
@@ -120,14 +149,16 @@ def query(self, query: str, w_context=True, add_to_context=True):
             msgs.append(query)
         else:
             msgs = [self.instruction_msgs, query]
-
+
         start_time = time.time()
         response = openai.ChatCompletion.create(model=self.model, messages=msgs)
         end_time = time.time()
 
         if self.log_info:
-            self.logger.info(f"The time taken for this response is : {end_time - start_time} seconds")
-            #print(f"The time taken for this response is : {end_time - start_time} seconds")
+            self.logger.info(
+                f"The time taken for this response is : {end_time - start_time} seconds"
+            )
+            # print(f"The time taken for this response is : {end_time - start_time} seconds")
 
         f_resp = response["choices"][0]["message"]
         tot_token_cnt = response["usage"]["total_tokens"]
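
For reference, the two fields the client reads out of the ChatCompletion response have roughly this shape (an abridged sketch; the values are made up):

    response = {
        "choices": [{"message": {"role": "assistant", "content": "..."}}],
        "usage": {"prompt_tokens": 56, "completion_tokens": 31, "total_tokens": 87},
    }
    f_resp = response["choices"][0]["message"]         # the reply message dict
    tot_token_cnt = response["usage"]["total_tokens"]  # 87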
@@ -138,9 +169,10 @@ def query(self, query: str, w_context=True, add_to_context=True):
 
         self.total_token_cnt += tot_token_cnt
 
-
         if self.log_info:
-            self.logger.info(f"The total token count currently is {self.total_token_cnt}")
+            self.logger.info(
+                f"The total token count currently is {self.total_token_cnt}"
+            )
 
         if add_to_context:
             self.prev_msgs.append(query)
@@ -149,54 +181,62 @@
             rsp_id = self.rsp_id
             self.rspid_vs_tottokens_dict[self.rsp_id] = tot_token_cnt
 
-        #print(self.prev_msgs)
+        # print(self.prev_msgs)
 
         return f_resp, rsp_id
 
-    def erase_history(self):
+    def erase_history(self) -> None:
+        """Removes all previous context."""
         self.prev_msgs = [self.instruction_msgs]
         self.rsp_id = len(self.prev_msgs)
 
-
     # This function is used for getting embeddings and hence maybe
     # used to speedup the system by caching.
-    def get_embeddings(self, query: str):
-        """_summary_
+    def get_embeddings(self, query: str) -> list:
+        """Calls the OpenAI embeddings API to get the generated embeddings for a query;
+        can be used later for caching and faster response times.
 
         Args:
-            query (str): _description_
+            query (str): Query whose embedding should be fetched from the OpenAI servers.
         """
         response = openai.Embedding.create(input=query, model="text-embedding-ada-002")
         embeddings = response["data"][0]["embedding"]
         return embeddings
 
-    def rollback_conversation(self, rsp_id):
-        """Rollback conversation to the point of a particular response.
+    def rollback_conversation(self, rsp_id: int) -> None:
+        """Rolls back the conversation to the point of a particular response.
 
         Args:
-            rsp_id (int): Id number of previous tracked response to roll back to.
+            rsp_id (int): Id number of a previously tracked response to roll back to.
         """
         self.prev_msgs = self.prev_msgs[0:rsp_id]
         self.rsp_id = len(self.prev_msgs)
-
-    def print_metrics(self):
-        self.logger.info(f"The total tokens used up-till now is: {self.total_token_cnt}")
-        self.logger.info(f"The average response time is: {sum(self.rsp_time_list)/len(self.rsp_time_list)} sec")
 
-        self.plot_rsp_times()
-
-    def plot_rsp_times(self):
+    def print_metrics(self) -> None:
+        """Method to log the token usage and response time information."""
+        self.logger.info(
+            f"The total tokens used up-till now is: {self.total_token_cnt}"
+        )
+        self.logger.info(
+            f"The average response time is: {sum(self.rsp_time_list)/len(self.rsp_time_list)} sec"
+        )
 
-        formatted_timestamps = [datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S') for ts in self.rsp_tstamp_list]
+        self.plot_rsp_times()
 
+    def plot_rsp_times(self) -> None:
+        """Method to plot the response times versus the timestamps."""
+        formatted_timestamps = [
+            datetime.fromtimestamp(ts).strftime("%Y-%m-%d %H:%M:%S")
+            for ts in self.rsp_tstamp_list
+        ]
 
         # Plot the response times against the timestamps
         fig, ax = plt.subplots()
         ax.plot(formatted_timestamps, self.rsp_time_list)
         ax.set_xlabel("Timestamp")
         ax.set_ylabel("Response Time (s)")
         ax.set_title("ChatGPT API Response Time")
-        ax.tick_params(axis='x', rotation=45)
+        ax.tick_params(axis="x", rotation=45)
         ax.xaxis.labelpad = 85
         ax.yaxis.labelpad = 35

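Putting rollback_conversation() and print_metrics() together, a session might branch and rewind like this (illustrative sketch, reusing the assumed `client` from earlier):

    _, rsp_id = client.query("Draft a haiku about databases.")
    client.query("Now translate it to French.")  # extends the saved context
    client.rollback_conversation(rsp_id)         # drops everything after the haiku
    client.print_metrics()                       # logs token totals and average latency, then plots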
@@ -211,51 +251,65 @@ def plot_rsp_times(self):
         # Show the plot
         plt.show()
 
-    def dump_context_to_a_file(self, filename="context"):
-        self.dicts_to_jsonl(data_list=self.prev_msgs, filename=filename)
+    def dump_context_to_a_file(self, filename: str = "context"):
+        """Wrapper method to call to dump the context to a file.
 
-
-    def dicts_to_jsonl(self, data_list: list, filename: str, compress: bool = True) -> None:
+        Args:
+            filename (str, optional): Filename to dump to. Defaults to "context".
         """
-        Method saves list of dicts into jsonl file.
-        :param data_list: (list) list of dicts to be stored,
-        :param filename: (str) path to the output file. If suffix .jsonl is not given then methods appends
-            .jsonl suffix into the file.
-        :param compress: (bool) should file be compressed into a gzip archive?
+        self.dicts_to_jsonl(data_list=self.prev_msgs, filename=filename)
+
+    def dicts_to_jsonl(self, data_list: list, filename: str) -> None:
+        """Method saves the context list of dicts to .jsonl format.
+
+        Args:
+            data_list (list): List of dicts of previous conversations.
+            filename (str): Filename to save to.
         """
-        sjsonl = '.jsonl'
+        sjsonl = ".jsonl"
 
         # Check filename
         if not filename.endswith(sjsonl):
             filename = filename + sjsonl
 
         # Save data
-        with open(filename, 'w') as out:
+        with open(filename, "w") as out:
             for ddict in data_list:
-                jout = json.dumps(ddict) + '\n'
+                jout = json.dumps(ddict) + "\n"
                 out.write(jout)
-
-    def load_context_from_a_file(self, filename):
 
-        sjsonl = '.jsonl'
+    def load_context_from_a_file(self, filename: str):
+        """Fills up self.prev_msgs, i.e. the context, from a dumped file.
+
+        Args:
+            filename (str): Filename should be a .jsonl file.
+
+        Returns:
+            list: List containing the context information.
+        """
+        sjsonl = ".jsonl"
 
         # Check filename
         if not filename.endswith(sjsonl):
             filename = filename + sjsonl
-
+
         self.prev_msgs = []
-
-        with open(filename, encoding='utf-8') as json_file:
+
+        with open(filename, encoding="utf-8") as json_file:
             for line in json_file.readlines():
                 self.prev_msgs.append(line)
 
         return self.prev_msgs
 
     def trim_conversation(self):
+        """Method to trim the context generated till now to below the token limit.
+        Queries are removed in FIFO manner until the token count goes below the limit.
+        """
         while sum(self.total_token_cnt_list) >= self.CONTEXT_TOKEN_LIMIT:
-
-            del self.total_token_cnt_list[0]
+            del self.total_token_cnt_list[0]
             del self.prev_msgs[1]
             del self.prev_msgs[2]
-
-        self.logger.info(f"Trimmed the context list to length: {sum(self.total_token_cnt_list)}")
+
+        self.logger.info(
+            f"Trimmed the context list to length: {sum(self.total_token_cnt_list)}"
+        )
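
The persistence round trip, sketched from the methods above; note that load_context_from_a_file() appends the raw JSON lines as strings rather than parsing them back into dicts:

    client.dump_context_to_a_file("context")  # writes context.jsonl
    restored = client.load_context_from_a_file("context")
    client.trim_conversation()  # FIFO-trims context until under CONTEXT_TOKEN_LIMIT (4000)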

tests/test_main.py (+4 -2)

@@ -8,9 +8,11 @@
 
 for _ in range(10):
     chatgptsmtclient.query("List the top 10 upcoming startups in India?")
-    chatgptsmtclient.query("Ok thanks, can you giv me the valuation of these startups in tabuar format")
+    chatgptsmtclient.query(
+        "Ok thanks, can you give me the valuation of these startups in tabular format"
+    )
 
 chatgptsmtclient.dump_context_to_a_file("context")
 chatgptsmtclient.load_context_from_a_file("context")
 
-chatgptsmtclient.print_metrics()
\ No newline at end of file
+chatgptsmtclient.print_metrics()
