import json
import os
+ import sys
import time
- from dataclasses import dataclass
- from typing import Optional
+ import traceback
+ from dataclasses import dataclass, field
+ from typing import List, Optional

import aiohttp
from tqdm.asyncio import tqdm
@@ -26,8 +28,11 @@ class RequestFuncOutput:
    generated_text: str = ""
    success: bool = False
    latency: float = 0
-     ttft: float = 0
+     ttft: float = 0  # Time to first token
+     itl: List[float] = field(
+         default_factory=list)  # List of inter-token latencies
    prompt_len: int = 0
+     error: str = ""


async def async_request_tgi(
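The new `ttft`, `itl`, and `error` fields give downstream analysis everything it needs per request: time to first token, the gap before each subsequent token, and a full traceback on failure. A minimal sketch of how they might be aggregated after a run (hypothetical helper, not part of this diff; `summarize` and `outputs` are invented names):

def summarize(outputs: List[RequestFuncOutput]) -> dict:
    # Failures now carry a traceback in `error`; aggregate successes only.
    completed = [o for o in outputs if o.success]
    # Flatten the per-request inter-token latencies into one distribution.
    itls = [latency for o in completed for latency in o.itl]
    return {
        "mean_ttft_ms": 1000 * sum(o.ttft for o in completed) /
                        max(len(completed), 1),
        "mean_itl_ms": 1000 * sum(itls) / max(len(itls), 1),
    }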
@@ -55,71 +60,38 @@ async def async_request_tgi(

        ttft = 0
        st = time.perf_counter()
+         most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
-                     async for data in response.content.iter_any():
-                         if ttft == 0:
-                             ttft = time.perf_counter() - st
-                             output.ttft = ttft
-                         output.latency = time.perf_counter() - st
-
-                     body = remove_prefix(data.decode("utf-8"), "data:")
-                     output.generated_text = json.loads(body)["generated_text"]
-                     output.success = True
-                 else:
-                     output.success = False
-         except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
-             output.success = False
-
-         if pbar:
-             pbar.update(1)
-         return output
-
-
- async def async_request_vllm(
-     request_func_input: RequestFuncInput,
-     pbar: Optional[tqdm] = None,
- ) -> RequestFuncOutput:
-     api_url = request_func_input.api_url
-     assert api_url.endswith("generate")
+                     async for chunk in response.content:
+                         chunk = chunk.strip()
+                         if not chunk:
+                             continue
-
+
-     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
-         payload = {
-             "prompt": request_func_input.prompt,
-             "n": 1,
-             "best_of": request_func_input.best_of,
-             "use_beam_search": request_func_input.use_beam_search,
-             "temperature": 0.0 if request_func_input.use_beam_search else 1.0,
-             "top_p": 1.0,
-             "max_tokens": request_func_input.output_len,
-             "ignore_eos": True,
-             "stream": True,
-         }
-         output = RequestFuncOutput()
-         output.prompt_len = request_func_input.prompt_len
+                         chunk = remove_prefix(chunk.decode("utf-8"), "data:")
-
+
-         ttft = 0
-         st = time.perf_counter()
-         try:
-             async with session.post(url=api_url, json=payload) as response:
-                 if response.status == 200:
-                     async for data in response.content.iter_any():
+                         data = json.loads(chunk)
+                         timestamp = time.perf_counter()
+                         # First token
                        if ttft == 0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft
-                         output.latency = time.perf_counter() - st

-                         # When streaming, '\0' is appended to the end of response.
-                         body = data.decode("utf-8").strip("\0")
-                         output.generated_text = json.loads(
-                             body)["text"][0][len(request_func_input.prompt):]
-                         output.success = True
+                         # Decoding phase
+                         else:
+                             output.itl.append(timestamp -
+                                               most_recent_timestamp)
-
+
-                 else:
-                     output.success = False
-             except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
+                         most_recent_timestamp = timestamp
+
+                     output.latency = most_recent_timestamp - st
+                     output.success = True
+                     output.generated_text = data["generated_text"]
+         except Exception:
            output.success = False
+             exc_info = sys.exc_info()
+             output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
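The loop above is the heart of the change, and the same bookkeeping is repeated in every streaming backend below: the delay before the first chunk is recorded as TTFT, and each later inter-chunk gap is appended to `itl`. Distilled into a standalone sketch over any async chunk iterator (illustrative names, not part of the diff):

async def time_stream(stream):
    # `stream` is any async iterable that yields one token chunk at a time.
    ttft = 0.0
    itl = []
    st = time.perf_counter()
    most_recent_timestamp = st
    async for _chunk in stream:
        timestamp = time.perf_counter()
        if ttft == 0.0:
            ttft = timestamp - st  # first chunk: time to first token
        else:
            itl.append(timestamp - most_recent_timestamp)  # inter-token gap
        most_recent_timestamp = timestamp
    # Total latency ends at the last chunk rather than at stream teardown.
    return ttft, itl, most_recent_timestamp - st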
@@ -146,26 +118,45 @@ async def async_request_trt_llm(
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
-         ttft = 0

+         ttft = 0
        st = time.perf_counter()
+         most_recent_timestamp = st
        try:
-             async with session.post(url=api_url, json=payload) as resp:
-                 if resp.status == 200:
-                     async for data in resp.content.iter_any():
+             async with session.post(url=api_url, json=payload) as response:
+                 if response.status == 200:
+                     async for chunk in response.content:
+                         chunk = chunk.strip()
+                         if not chunk:
+                             continue
+
+                         chunk = remove_prefix(chunk.decode("utf-8"), "data:")
+
+                         data = json.loads(chunk)
+                         timestamp = time.perf_counter()
+                         # First token
                        if ttft == 0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft
-                         output.latency = time.perf_counter() - st

-                         body = remove_prefix(data.decode("utf-8"), "data:")
-                         output.generated_text = json.loads(body)["text_output"]
+                         # Decoding phase
+                         else:
+                             output.itl.append(timestamp -
+                                               most_recent_timestamp)
+
+                         most_recent_timestamp = timestamp
+
+                     output.latency = most_recent_timestamp - st
+                     output.generated_text = data["text_output"]
                    output.success = True

                else:
+                     output.error = response.reason
                    output.success = False
-         except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
+         except Exception:
            output.success = False
+             exc_info = sys.exc_info()
+             output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
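Both streaming handlers strip the server-sent-events `data:` framing with the module's `remove_prefix` helper (its signature appears in the last hunk below). Its body is presumably the usual backport of `str.removeprefix`, which only exists on Python 3.9+; something along these lines:

def remove_prefix(text: str, prefix: str) -> str:
    # Equivalent to text.removeprefix(prefix) on Python 3.9+.
    if text.startswith(prefix):
        return text[len(prefix):]
    return text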
@@ -181,35 +172,35 @@ async def async_request_deepspeed_mii(
        assert not request_func_input.use_beam_search

        payload = {
-             "prompts": request_func_input.prompt,
-             "max_new_tokens": request_func_input.output_len,
-             "ignore_eos": True,
-             "do_sample": True,
-             "temperature":
-             0.01,  # deepspeed-mii does not accept 0.0 temperature.
+             "prompt": request_func_input.prompt,
+             "max_tokens": request_func_input.output_len,
+             "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

-         # DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
+         # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
-         # https://github.com/microsoft/DeepSpeed-MII/pull/311
+         # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        output.ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
-                                     json=payload) as resp:
-                 if resp.status == 200:
-                     parsed_resp = await resp.json()
+                                     json=payload) as response:
+                 if response.status == 200:
+                     parsed_resp = await response.json()
                    output.latency = time.perf_counter() - st
-                     output.generated_text = parsed_resp[0]["generated_text"]
+                     output.generated_text = parsed_resp["text"][0]
                    output.success = True
                else:
+                     output.error = response.reason
                    output.success = False
-         except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
+         except Exception:
            output.success = False
+             exc_info = sys.exc_info()
+             output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
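For reference, the new parsing implies the MII server now answers with a body shaped roughly like the snippet below. This shape is inferred from `parsed_resp["text"][0]` and is an assumption, not something stated in the diff:

# Hypothetical response shape implied by the parsing above:
parsed_resp = {"text": ["<completion for the single prompt>"]}
generated_text = parsed_resp["text"][0]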
@@ -221,7 +212,9 @@ async def async_request_openai_completions(
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
-     assert api_url.endswith("v1/completions")
+     assert api_url.endswith(
+         "v1/completions"
+     ), "OpenAI Completions API URL must end with 'v1/completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
@@ -243,15 +236,12 @@ async def async_request_openai_completions(
        generated_text = ""
        ttft = 0
        st = time.perf_counter()
+         most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk in response.content:
-                         if ttft == 0:
-                             ttft = time.perf_counter() - st
-                             output.ttft = ttft
-
                        chunk = chunk.strip()
                        if not chunk:
                            continue
@@ -260,16 +250,33 @@ async def async_request_openai_completions(
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
-                             body = json.loads(chunk)
-                             generated_text += body["choices"][0]["text"]
+                             data = json.loads(chunk)
+
+                             if data["choices"][0]["text"]:
+                                 timestamp = time.perf_counter()
+                                 # First token
+                                 if ttft == 0:
+                                     ttft = time.perf_counter() - st
+                                     output.ttft = ttft
+
+                                 # Decoding phase
+                                 # NOTE: Some completion APIs send a final
+                                 # usage-summary response with no token, which
+                                 # must not count as an inter-token latency.
+                                 elif data.get("usage", None) is None:
+                                     output.itl.append(timestamp -
+                                                       most_recent_timestamp)
+
+                                 most_recent_timestamp = timestamp
+                                 generated_text += data["choices"][0]["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
-                 else:
-                     output.success = False
-         except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
+         except Exception:
            output.success = False
+             exc_info = sys.exc_info()
+             output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
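The `usage` guard exists because some OpenAI-compatible servers close the stream with a summary chunk that carries token counts but no text; timing its arrival would distort the ITL tail. An illustrative stream follows (payloads vary by server, so treat the exact fields as assumptions):

data: {"choices": [{"text": "Hello"}]}
data: {"choices": [{"text": " world"}]}
data: {"choices": [{"text": ""}], "usage": {"completion_tokens": 2}}
data: [DONE]

The first line sets `ttft`, the second appends to `itl`, the third is skipped by the `usage` check, and the fourth stops the timer for total latency.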
@@ -283,7 +290,7 @@ async def async_request_openai_chat_completions(
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "v1/chat/completions"
-     ), "OpenAI Chat API URL must end with 'v1/chat/completions'."
+     ), "OpenAI Chat Completions API URL must end with 'v1/chat/completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
@@ -301,7 +308,7 @@ async def async_request_openai_chat_completions(
        }
        headers = {
            "Content-Type": "application/json",
-             "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
+             "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        output = RequestFuncOutput()
@@ -310,15 +317,12 @@ async def async_request_openai_chat_completions(
        generated_text = ""
        ttft = 0
        st = time.perf_counter()
+         most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk in response.content:
-                         if ttft == 0:
-                             ttft = time.perf_counter() - st
-                             output.ttft = ttft
-
                        chunk = chunk.strip()
                        if not chunk:
                            continue
@@ -327,18 +331,35 @@ async def async_request_openai_chat_completions(
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
-                             body = json.loads(chunk)
-                             if "content" in body["choices"][0]["delta"]:
-                                 generated_text += body["choices"][0]["delta"][
+                             timestamp = time.perf_counter()
+                             data = json.loads(chunk)
+
+                             if "content" in data["choices"][0]["delta"]:
+                                 # First token
+                                 if ttft == 0:
+                                     ttft = time.perf_counter() - st
+                                     output.ttft = ttft
+
+                                 # Decoding phase
+                                 else:
+                                     output.itl.append(timestamp -
+                                                       most_recent_timestamp)
+
+                                 generated_text += data["choices"][0]["delta"][
                                    "content"]

+                             most_recent_timestamp = timestamp
+
                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
+                     output.error = response.reason
                    output.success = False
-         except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
+         except Exception:
            output.success = False
+             exc_info = sys.exc_info()
+             output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
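The chat variant reads incremental text from the `delta` object instead of `text`. A typical content chunk looks roughly like the line below (illustrative; exact fields vary by server), and the `"content" in ...` membership test also skips deltas without content, such as an initial role-only delta, so they never perturb the timing:

data: {"choices": [{"delta": {"content": "Hello"}}]}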
@@ -355,7 +376,8 @@ def remove_prefix(text: str, prefix: str) -> str:

ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
-     "vllm": async_request_vllm,
+     "vllm": async_request_openai_completions,
+     "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,