Commit 45b6ef6

feat(benchmarks): Add Prefix Caching Benchmark to Serving Benchmark (vllm-project#3277)

1 parent 1956931

File tree

6 files changed: +897 -155 lines changed

.buildkite/run-benchmarks.sh (+3 -2)
@@ -23,8 +23,9 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r
 # wait for server to start, timeout after 600 seconds
 timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
 python3 benchmarks/benchmark_serving.py \
-    --backend openai \
-    --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
+    --backend vllm \
+    --dataset-name sharegpt \
+    --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \
     --model meta-llama/Llama-2-7b-chat-hf \
     --num-prompts 20 \
     --endpoint /v1/completions \
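
The CI job now goes through the vllm backend name, with the old --dataset flag split into --dataset-name and --dataset-path. As a hand-run sketch of the same invocation, assuming an OpenAI-compatible vLLM server is already answering the health check on localhost:8000 as the script expects:

# Sketch only: mirrors the CI flags above; server startup is out of scope here.
python3 benchmarks/benchmark_serving.py \
    --backend vllm \
    --dataset-name sharegpt \
    --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \
    --model meta-llama/Llama-2-7b-chat-hf \
    --num-prompts 20 \
    --endpoint /v1/completions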

benchmarks/backend_request_func.py (+121 -99)
@@ -1,8 +1,10 @@
 import json
 import os
+import sys
 import time
-from dataclasses import dataclass
-from typing import Optional
+import traceback
+from dataclasses import dataclass, field
+from typing import List, Optional
 
 import aiohttp
 from tqdm.asyncio import tqdm
@@ -26,8 +28,11 @@ class RequestFuncOutput:
     generated_text: str = ""
     success: bool = False
     latency: float = 0
-    ttft: float = 0
+    ttft: float = 0  # Time to first token
+    itl: List[float] = field(
+        default_factory=list)  # List of inter-token latencies
     prompt_len: int = 0
+    error: str = ""
 
 
 async def async_request_tgi(
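
RequestFuncOutput now carries one inter-token latency (ITL) sample per streamed chunk after the first, plus an error string for failed requests. A minimal sketch of how a driver could aggregate the new field; summarize_itl and its percentile choices are illustrative, not part of this commit:

import numpy as np

def summarize_itl(outputs):
    # `outputs` are RequestFuncOutput objects returned by the request funcs;
    # only the field names (`success`, `itl`) come from this diff.
    itls = [lat for o in outputs if o.success for lat in o.itl]
    return {
        "mean_itl_ms": 1000 * float(np.mean(itls)),
        "p99_itl_ms": 1000 * float(np.percentile(itls, 99)),
    }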
@@ -55,71 +60,38 @@ async def async_request_tgi(
 
         ttft = 0
         st = time.perf_counter()
+        most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload) as response:
                 if response.status == 200:
-                    async for data in response.content.iter_any():
-                        if ttft == 0:
-                            ttft = time.perf_counter() - st
-                            output.ttft = ttft
-                    output.latency = time.perf_counter() - st
-
-                    body = remove_prefix(data.decode("utf-8"), "data:")
-                    output.generated_text = json.loads(body)["generated_text"]
-                    output.success = True
-                else:
-                    output.success = False
-        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
-            output.success = False
-
-    if pbar:
-        pbar.update(1)
-    return output
-
-
-async def async_request_vllm(
-    request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
-) -> RequestFuncOutput:
-    api_url = request_func_input.api_url
-    assert api_url.endswith("generate")
+                    async for chunk in response.content:
+                        chunk = chunk.strip()
+                        if not chunk:
+                            continue
 
-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
-        payload = {
-            "prompt": request_func_input.prompt,
-            "n": 1,
-            "best_of": request_func_input.best_of,
-            "use_beam_search": request_func_input.use_beam_search,
-            "temperature": 0.0 if request_func_input.use_beam_search else 1.0,
-            "top_p": 1.0,
-            "max_tokens": request_func_input.output_len,
-            "ignore_eos": True,
-            "stream": True,
-        }
-        output = RequestFuncOutput()
-        output.prompt_len = request_func_input.prompt_len
+                        chunk = remove_prefix(chunk.decode("utf-8"), "data:")
 
-        ttft = 0
-        st = time.perf_counter()
-        try:
-            async with session.post(url=api_url, json=payload) as response:
-                if response.status == 200:
-                    async for data in response.content.iter_any():
+                        data = json.loads(chunk)
+                        timestamp = time.perf_counter()
+                        # First token
                         if ttft == 0:
                             ttft = time.perf_counter() - st
                             output.ttft = ttft
-                    output.latency = time.perf_counter() - st
 
-                    # When streaming, '\0' is appended to the end of response.
-                    body = data.decode("utf-8").strip("\0")
-                    output.generated_text = json.loads(
-                        body)["text"][0][len(request_func_input.prompt):]
-                    output.success = True
+                        # Decoding phase
+                        else:
+                            output.itl.append(timestamp -
+                                              most_recent_timestamp)
 
-                else:
-                    output.success = False
-        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
+                        most_recent_timestamp = timestamp
+
+                    output.latency = most_recent_timestamp - st
+                    output.success = True
+                    output.generated_text = data["generated_text"]
+        except Exception:
             output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
 
     if pbar:
         pbar.update(1)
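
All streaming backends below now share one measurement pattern: timestamp each SSE chunk, treat the first as time-to-first-token, and record the gap from the previous chunk as an inter-token latency. Stripped of HTTP details, the bookkeeping reduces to the sketch below (the synchronous wrapper is illustrative only):

import time

def time_stream(chunks):
    # Returns (ttft, itl, latency) in seconds for any iterable of chunks,
    # mirroring the ttft/itl/most_recent_timestamp logic in the diff above.
    ttft, itl = 0.0, []
    st = time.perf_counter()
    most_recent_timestamp = st
    for _ in chunks:
        timestamp = time.perf_counter()
        if ttft == 0.0:
            ttft = timestamp - st  # first token
        else:
            itl.append(timestamp - most_recent_timestamp)  # decoding phase
        most_recent_timestamp = timestamp
    return ttft, itl, most_recent_timestamp - st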
@@ -146,26 +118,45 @@ async def async_request_trt_llm(
         }
         output = RequestFuncOutput()
         output.prompt_len = request_func_input.prompt_len
-        ttft = 0
 
+        ttft = 0
         st = time.perf_counter()
+        most_recent_timestamp = st
         try:
-            async with session.post(url=api_url, json=payload) as resp:
-                if resp.status == 200:
-                    async for data in resp.content.iter_any():
+            async with session.post(url=api_url, json=payload) as response:
+                if response.status == 200:
+                    async for chunk in response.content:
+                        chunk = chunk.strip()
+                        if not chunk:
+                            continue
+
+                        chunk = remove_prefix(chunk.decode("utf-8"), "data:")
+
+                        data = json.loads(chunk)
+                        timestamp = time.perf_counter()
+                        # First token
                         if ttft == 0:
                             ttft = time.perf_counter() - st
                             output.ttft = ttft
-                    output.latency = time.perf_counter() - st
 
-                    body = remove_prefix(data.decode("utf-8"), "data:")
-                    output.generated_text = json.loads(body)["text_output"]
+                        # Decoding phase
+                        else:
+                            output.itl.append(timestamp -
+                                              most_recent_timestamp)
+
+                        most_recent_timestamp = timestamp
+
+                    output.latency = most_recent_timestamp - st
+                    output.generated_text = json.loads(data)["text_output"]
                     output.success = True
 
                 else:
+                    output.error = response.reason
                     output.success = False
-        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
+        except Exception:
             output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
 
     if pbar:
         pbar.update(1)
@@ -181,35 +172,35 @@ async def async_request_deepspeed_mii(
         assert not request_func_input.use_beam_search
 
         payload = {
-            "prompts": request_func_input.prompt,
-            "max_new_tokens": request_func_input.output_len,
-            "ignore_eos": True,
-            "do_sample": True,
-            "temperature":
-            0.01,  # deepspeed-mii does not accept 0.0 temperature.
+            "prompt": request_func_input.prompt,
+            "max_tokens": request_func_input.output_len,
+            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
             "top_p": 1.0,
         }
         output = RequestFuncOutput()
         output.prompt_len = request_func_input.prompt_len
 
-        # DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
+        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
         # will use 0 as placeholder.
-        # https://github.com/microsoft/DeepSpeed-MII/pull/311
+        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
         output.ttft = 0
 
         st = time.perf_counter()
         try:
             async with session.post(url=request_func_input.api_url,
-                                    json=payload) as resp:
-                if resp.status == 200:
-                    parsed_resp = await resp.json()
+                                    json=payload) as response:
+                if response.status == 200:
+                    parsed_resp = await response.json()
                     output.latency = time.perf_counter() - st
-                    output.generated_text = parsed_resp[0]["generated_text"]
+                    output.generated_text = parsed_resp["text"][0]
                     output.success = True
                 else:
+                    output.error = response.reason
                     output.success = False
-        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
+        except Exception:
             output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
 
     if pbar:
         pbar.update(1)
@@ -221,7 +212,9 @@ async def async_request_openai_completions(
     pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
-    assert api_url.endswith("v1/completions")
+    assert api_url.endswith(
+        "v1/completions"
+    ), "OpenAI Completions API URL must end with 'v1/completions'."
 
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         assert not request_func_input.use_beam_search
@@ -243,15 +236,12 @@ async def async_request_openai_completions(
         generated_text = ""
         ttft = 0
         st = time.perf_counter()
+        most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload,
                                     headers=headers) as response:
                 if response.status == 200:
                     async for chunk in response.content:
-                        if ttft == 0:
-                            ttft = time.perf_counter() - st
-                            output.ttft = ttft
-
                         chunk = chunk.strip()
                         if not chunk:
                             continue
@@ -260,16 +250,33 @@ async def async_request_openai_completions(
                         if chunk == "[DONE]":
                             latency = time.perf_counter() - st
                         else:
-                            body = json.loads(chunk)
-                            generated_text += body["choices"][0]["text"]
+                            data = json.loads(chunk)
+
+                            if data["choices"][0]["text"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                # NOTE: Some completion API might have a last
+                                # usage summary response without a token so we
+                                # do not want to include as inter-token-latency
+                                elif data.get("usage", None) is None:
+                                    output.itl.append(timestamp -
+                                                      most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text += data["choices"][0]["text"]
 
                     output.generated_text = generated_text
                     output.success = True
                     output.latency = latency
-                else:
-                    output.success = False
-        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
+        except Exception:
             output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
 
     if pbar:
         pbar.update(1)
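
The elif data.get("usage", None) is None guard exists because some OpenAI-compatible servers stream a final chunk that carries a usage summary rather than a token; counting the gap before it would inflate the last inter-token latency. A toy trace of the guard, with invented chunks:

import json
import time

# Invented sample stream: two plain token chunks, then one that also
# carries a usage summary and must not contribute an ITL sample.
chunks = [
    '{"choices": [{"text": "Hello"}], "usage": null}',
    '{"choices": [{"text": " world"}], "usage": null}',
    '{"choices": [{"text": "!"}], "usage": {"completion_tokens": 3}}',
]

ttft, itl = 0.0, []
st = time.perf_counter()
most_recent_timestamp = st
for raw in chunks:
    data = json.loads(raw)
    if data["choices"][0]["text"]:
        timestamp = time.perf_counter()
        if ttft == 0.0:
            ttft = timestamp - st  # first token
        elif data.get("usage", None) is None:
            itl.append(timestamp - most_recent_timestamp)
        most_recent_timestamp = timestamp

print(len(itl))  # 1: the usage-bearing chunk is excluded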
@@ -283,7 +290,7 @@ async def async_request_openai_chat_completions(
     api_url = request_func_input.api_url
     assert api_url.endswith(
         "v1/chat/completions"
-    ), "OpenAI Chat API URL must end with 'v1/chat/completions'."
+    ), "OpenAI Chat Completions API URL must end with 'v1/chat/completions'."
 
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         assert not request_func_input.use_beam_search
@@ -301,7 +308,7 @@ async def async_request_openai_chat_completions(
         }
         headers = {
             "Content-Type": "application/json",
-            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
+            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
         }
 
         output = RequestFuncOutput()
@@ -310,15 +317,12 @@ async def async_request_openai_chat_completions(
         generated_text = ""
         ttft = 0
         st = time.perf_counter()
+        most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload,
                                     headers=headers) as response:
                 if response.status == 200:
                     async for chunk in response.content:
-                        if ttft == 0:
-                            ttft = time.perf_counter() - st
-                            output.ttft = ttft
-
                         chunk = chunk.strip()
                         if not chunk:
                             continue
@@ -327,18 +331,35 @@ async def async_request_openai_chat_completions(
                         if chunk == "[DONE]":
                             latency = time.perf_counter() - st
                         else:
-                            body = json.loads(chunk)
-                            if "content" in body["choices"][0]["delta"]:
-                                generated_text += body["choices"][0]["delta"][
+                            timestamp = time.perf_counter()
+                            data = json.loads(chunk)
+
+                            if "content" in data["choices"][0]["delta"]:
+                                # First token
+                                if ttft == 0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp -
+                                                      most_recent_timestamp)
+
+                                generated_text += data["choices"][0]["delta"][
                                     "content"]
 
+                            most_recent_timestamp = timestamp
+
                     output.generated_text = generated_text
                     output.success = True
                     output.latency = latency
                 else:
+                    output.error = response.reason
                     output.success = False
-        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
+        except Exception:
             output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
 
     if pbar:
         pbar.update(1)
@@ -355,7 +376,8 @@ def remove_prefix(text: str, prefix: str) -> str:
 
 ASYNC_REQUEST_FUNCS = {
     "tgi": async_request_tgi,
-    "vllm": async_request_vllm,
+    "vllm": async_request_openai_completions,
+    "lmdeploy": async_request_openai_completions,
     "deepspeed-mii": async_request_deepspeed_mii,
     "openai": async_request_openai_completions,
     "openai-chat": async_request_openai_chat_completions,