diff --git a/vllm_omni/benchmarks/__init__.py b/vllm_omni/benchmarks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm_omni/benchmarks/datasets.py b/vllm_omni/benchmarks/datasets.py new file mode 100644 index 000000000..e67eaa61e --- /dev/null +++ b/vllm_omni/benchmarks/datasets.py @@ -0,0 +1,1094 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This module defines a framework for sampling benchmark requests from various +datasets. Each dataset subclass of BenchmarkDataset must implement sample +generation. Supported dataset types include: + - ShareGPT + - Random (synthetic) + - Sonnet + - BurstGPT + - HuggingFace + - VisionArena +""" +import ast +import base64 +import io +import json +import logging +import math +import os +import tempfile +from collections.abc import Iterator, Mapping +from contextlib import suppress +from typing import Any, cast, Dict + +import cv2 +import numpy as np +import torch +import torchaudio +from PIL import Image +from transformers import PreTrainedTokenizerBase +from vllm.benchmarks.datasets import (RandomDataset, ShareGPTDataset, SpecBench, + SonnetDataset, BurstGPTDataset, ConversationDataset, + VisionArenaDataset, MMVUDataset, InstructCoderDataset, MTBenchDataset, + BlazeditDataset, AIMODataset, NextEditPredictionDataset, ASRDataset, MLPerfDataset, + PrefixRepetitionRandomDataset, CustomDataset, SampleRequest, _ValidateDatasetArgs, + process_image) +from vllm.utils import PlaceholderModule + +try: + from datasets import load_dataset +except ImportError: + datasets = PlaceholderModule("datasets") + load_dataset = datasets.placeholder_attr("load_dataset") + +try: + import pandas as pd +except ImportError: + pd = PlaceholderModule("pandas") + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") + +try: + from vllm.utils import FlexibleArgumentParser +except ImportError: + from argparse import ArgumentParser as FlexibleArgumentParser + +logger = logging.getLogger(__name__) + + +def process_video(video: Any) -> Mapping[str, Any]: + """ + Process a single video input and return a multimedia content dictionary. + + Supports the following input types: + + 1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key + containing raw video data. + + 2. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is not a supported type. + """ + if isinstance(video, dict) and 'bytes' in video: + video_bytes = video['bytes'] + video_base64 = base64.b64encode(video_bytes).decode("utf-8") + return { + "type": "video_url", + "video_url": { + "url": f"data:video/mp4;base64,{video_base64}" + }, + } + + if isinstance(video, str): + video_url = (video if video.startswith( + ("http://", "https://", "file://")) else f"file://{video}") + return {"type": "video_url", "video_url": {"url": video_url}} + + raise ValueError( + f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501 + ) + + +def process_audio(audio: Any) -> Mapping[str, Any]: + """ + Process a single audio input and return a multimedia content dictionary. + + Supports the following input types: + + 1. Dictionary with raw audio bytes: - Expects a dict with a 'bytes' key + containing raw audio data. + + 2. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the audio URL. + + Raises: + ValueError: If the input is not a supported type. + """ + if isinstance(audio, dict) and 'bytes' in audio: + audio_bytes = audio['bytes'] + audio_base64 = base64.b64encode(audio_bytes).decode('utf-8') + return { + "type": "audio_url", + "audio_url": { + "url": f"data:audio/mpeg;base64,{audio_base64}" + }, + } + if isinstance(audio, str): + audio_url = (audio if audio.startswith( + ("http://", "https://", "file://")) else f"file://{audio}") + return {"type": "audio_url", "audio_url": {"url": audio_url}} + + raise ValueError(f"Invalid audio input {audio}. Must be a string of local path/remote url, or a dictionary with raw audio bytes in the form of `{{'bytes': raw_audio_bytes}}`." + ) + + + +# ----------------------------------------------------------------------------- +# MultiModalDataset Implementation +# ----------------------------------------------------------------------------- + +class RandomMultiModalDataset(RandomDataset): + """ + Synthetic multimodal dataset (text + images) that extends RandomDataset. + + Status: + - Images: supported via synthetic RGB data. + - Video: supported via synthetic bytes data. + - Audio: supported via synthetic bytes data. + + Sampling overview: + 1) Number of items per request is sampled uniformly from the integer range + [floor(n·(1−r)), ceil(n·(1+r))], where n is the base count and r is + `num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0. + The maximum is further clamped to the sum of per-modality limits. + 2) Each item’s modality and shape is sampled from `bucket_config`, a dict + mapping (height, width, num_frames) → probability. We treat + `num_frames`=1 as image and and `num_frames` > 1 as video. + Entries with zero probability are removed and the rest are renormalized + to sum to 1. + 3) Per-modality hard caps are enforced via `limit_mm_per_prompt`. + When a modality reaches its cap, all of its buckets are excluded and the + remaining probabilities are renormalized. + + Example bucket configuration: + {(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1} + - Two image buckets (`num_frames`=1) and one video bucket + (`num_frames`=16). + OBS.: Only image sampling is supported for now. + """ + + IS_MULTIMODAL = True + # NOTE: video sampling is WIP. Setting it to 0. + DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 0} + + DEFAULT_BASE_ITEMS_PER_REQUEST = 1 + DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0 + DEFAULT_MM_ITEM_BUCKET_CONFIG = { + (256, 256, 1): 0.5, + (720, 1280, 1): 0.5, + (720, 1280, 16): 0.0, + } + DEFAULT_ENABLE_MULTIMODAL_CHAT = False + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + + def generate_synthetic_image(self, width: int, height: int) -> Image.Image: + """Generate synthetic PIL image with random RGB values. + + NOTE: iid pixel sampling results in worst-case compression + (good for stressing I/O), but very unlike real photos. + We could consider a “low-freq” mode (e.g., noise blur) + to emulate network realism instead of max stress. + """ + random_pixels = self._rng.integers( + 0, + 256, + (height, width, 3), + dtype=np.uint8, + ) + return Image.fromarray(random_pixels) + + def generate_synthetic_video(self, width: int, + height: int, + num_frames: int) -> Any: + """Generate synthetic video with random values. + """ + video_data = self._rng.integers( + 0, 256, + (num_frames, height, width, 3), + dtype=np.uint8, + ) + video_tensor = torch.from_numpy(video_data) + with tempfile.NamedTemporaryFile(suffix=f".mp4", delete=False) as tmp: + temp_path = tmp.name + frames, height, width, channels = video_tensor.shape + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + out = cv2.VideoWriter(temp_path, fourcc, 30, (width, height)) + + for i in range(frames): + frame = video_tensor[i].numpy() + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + out.write(frame) + out.release() + + with open(temp_path, 'rb') as f: + video_bytes = f.read() + + os.unlink(temp_path) + + return { + 'bytes': video_bytes, + } + + + def generate_synthetic_audio( + self, + duration: int, # seconds + num_channels: int #1:Mono,2:Stereo 5:5.1 surround sound + ) -> Dict[str, Any]: + """Generate synthetic audio with random values. + Default use 48000Hz. + """ + sample_rate = 48000 + num_samples = int(sample_rate * duration) + audio_data = self._rng.uniform( + -0.5, 0.5, + (num_samples, num_channels) + ) + audio_data = np.clip(audio_data, -1.0, 1.0) + audio_tensor = torch.FloatTensor(audio_data.T) + buffer = io.BytesIO() + torchaudio.save( + buffer, + audio_tensor, + sample_rate, + format="mp3" + ) + buffer.seek(0) + audio_bytes = buffer.read() + return { + 'bytes': audio_bytes, + } + + def map_config_to_modality(self, config: tuple[int, int, int]) -> str: + """Map the configuration to the modality.""" + if config[0] == 0: + return "audio" + elif config[-1] == 1: + return "image" + elif config[-1] > 1: + return "video" + else: + raise ValueError(f"Invalid multimodal item configuration: {config}") + + def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int], + float]) -> dict[tuple[int, int, int], float]: + """ + Remove zero probability entries + and normalize the bucket config to sum to 1. + """ + # Raise error if value is negative + if any(v < 0 for v in bucket_config.values()): + raise ValueError("Bucket config values must be non-negative.") + # Remove zero probability entries + bucket_config = {k: v for k, v in bucket_config.items() if v > 0} + # if bucket config is empty, raise error + if not bucket_config: + raise ValueError("Got invalid bucket config. " + "Bucket config values must be non-zero.") + # Normalize the remaining bucket config to sum to 1 + total = sum(bucket_config.values()) + return {k: v / total for k, v in bucket_config.items()} + + + def generate_mm_item(self, + mm_item_config: tuple[int, int, int], + ) -> Mapping[str, Any]: + """ + Create synthetic images and videos and + apply process_image/process_video respectively. + This follows the OpenAI API chat completions + https://github.com/openai/openai-python + """ + + if self.map_config_to_modality(mm_item_config) == "image": + return process_image(self.generate_synthetic_image( + mm_item_config[1], + mm_item_config[0])) + elif self.map_config_to_modality(mm_item_config) == "video": + return process_video(self.generate_synthetic_video( + mm_item_config[1], + mm_item_config[0], + mm_item_config[2])) + elif self.map_config_to_modality(mm_item_config) == "audio": + return process_audio(self.generate_synthetic_audio( + mm_item_config[1], + mm_item_config[2])) + else: + raise ValueError(f"Invalid multimodal item configuration: " + f"{mm_item_config}") + + + def get_mm_item_sampling_params( + self, + base_items_per_request: int, + num_mm_items_range_ratio: float, + limit_mm_per_prompt: dict[str, int], + bucket_config: dict[tuple[int, int, int], float], + ) -> tuple[int, int, dict[str, int], dict[tuple[int, int, int], float]]: + """ + Get the sampling parameters for the multimodal items. + """ + # Enforce num_mm_items_range_ratio <= 1 + if not (0.0 <= num_mm_items_range_ratio <= 1.0): + raise ValueError("num_mm_items_range_ratio must be in [0, 1].") + + # Ensure modalities to sample are in limit_mm_per_prompt + for k, v in bucket_config.items(): + # get modality from bucket config + modality = self.map_config_to_modality(k) + if modality not in limit_mm_per_prompt: + raise ValueError(f"Modality {modality} is not in " + f"limit_mm_per_prompt: " + f"{limit_mm_per_prompt.keys()}") + + # Remove zero probability entries + # and normalize bucket config to sum to 1 + bucket_config = self.normalize_bucket_config(bucket_config) + logger.info( + "Normalized bucket config: %s", bucket_config, + ) + # Only consider limit per prompt for modalities in bucket config + allowed_modalities = {self.map_config_to_modality(cfg) + for cfg in bucket_config} + limit_mm_per_prompt = { + k: v for k, v in limit_mm_per_prompt.items() + if k in allowed_modalities} + if not limit_mm_per_prompt: + raise ValueError("No valid limits for modalities present in " + "bucket_config.") + + logger.info( + "Updated mm-limit-per-prompt: %s", limit_mm_per_prompt, + ) + + # Get max and min num mm items and ensure + # it is at most the sum of limit_mm_per_prompt for all modalities + max_num_mm_items = min( + sum(limit_mm_per_prompt.values()), + math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio)) + ) + # Ensure min num mm items is at least 0 + min_num_mm_items = max( + 0, + math.floor(base_items_per_request * (1 - num_mm_items_range_ratio)) + ) + # Raise error if min num mm items is greater than max num mm items + if min_num_mm_items > max_num_mm_items: + raise ValueError(f"Min num mm items is greater than max mm items: " + f"{min_num_mm_items} > {max_num_mm_items}") + + logger.info( + "Sampling number of multimodal items from [%s, %s]", + min_num_mm_items, max_num_mm_items, + ) + + return ( + min_num_mm_items, + max_num_mm_items, + limit_mm_per_prompt, + bucket_config, + ) + + def get_mm_item_iterator( + self, + min_num_mm_items: int, + max_num_mm_items: int, + bucket_config: dict[tuple[int, int, int], float], + limit_mm_per_prompt: dict[str, int], + ) -> Iterator[tuple[int,int, int]]: + """ + Iterator over the multimodal items for each request + whose size is between min_num_mm_items and max_num_mm_items. + + Loop over the bucket config and sample a multimodal item. + Loop until the number of multimodal items sampled is equal to + request_num_mm_items or limit of multimodal items per prompt + for all modalities is reached. + + Note: + - This function operates on a per-request shallow copy of + `bucket_config` (tuple->float). The original dict passed to + `sample` is not mutated. If this ever changes, a test + is implemented and will fail. + """ + # Get the number of multimodal items to sample + request_num_mm_items = int( + self._rng.integers(min_num_mm_items, max_num_mm_items + 1) + ) + # If request_num_mm_items is 0, yield an empty iterator + if request_num_mm_items == 0: + return + # Initialize modality counters + modality_counter = {self.map_config_to_modality(k): 0 + for k in bucket_config} + # Copy the bucket config to avoid modifying the original + bucket_config_copy = bucket_config.copy() + # Loop over the number of multimodal items to sample + while sum(modality_counter.values()) < request_num_mm_items: + # Sample a multimodal item config + mm_item_config = self._rng.choice(list(bucket_config_copy.keys()), + p=list(bucket_config_copy.values())) + modality = self.map_config_to_modality(mm_item_config) + # Check that modality count is less than limit per prompt + if modality_counter[modality] < limit_mm_per_prompt[modality]: + modality_counter[modality] += 1 + yield ( + mm_item_config + ) + else: + # If the counter is greater than the limit per prompt + # set all multimodal items of this modality to 0 + for k, v in bucket_config_copy.items(): + if self.map_config_to_modality(k) == modality: + bucket_config_copy[k] = 0 + # If all configs are 0, break the loop + # This should not happen as request_num_mm_items is at most + # the sum of limit_mm_per_prompt for all modalities + if all(v == 0 for v in bucket_config_copy.values()): + logger.warning("Exhausted all multimodal items " + "of modality %s", + modality) + break + # Renormalize the bucket config + bucket_config_copy = self.normalize_bucket_config( + bucket_config_copy) + + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + no_oversample: bool = False, + prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN, + range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO, + input_len: int = RandomDataset.DEFAULT_INPUT_LEN, + output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN, + limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT, + base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST, + num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, + bucket_config: dict[tuple[int, int, int], float] = + DEFAULT_MM_ITEM_BUCKET_CONFIG, + enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT, + **kwargs, + ) -> list[SampleRequest]: + + # Get the sampling parameters for the dataset + input_lens, output_lens, offsets = self.get_sampling_params( + num_requests, range_ratio, input_len, output_len, tokenizer + ) + + ( + min_num_mm_items, + max_num_mm_items, + limit_mm_per_prompt, + bucket_config, + ) = self.get_mm_item_sampling_params( + base_items_per_request, + num_mm_items_range_ratio, + limit_mm_per_prompt, + bucket_config, + ) + + # Generate prefix once + prefix_token_ids = self.get_prefix(tokenizer, prefix_len) + vocab_size = tokenizer.vocab_size + # Add synthetic multimodal items to each request + mm_requests = [] + for i in range(num_requests): + prompt, total_input_len = self.generate_token_sequence( + tokenizer=tokenizer, + prefix_token_ids=prefix_token_ids, + prefix_len=prefix_len, + vocab_size=vocab_size, + input_len=int(input_lens[i]), + offset=int(offsets[i]), + index=i, + ) + # Get multimodal item iterator for a given request + mm_item_iterator = self.get_mm_item_iterator( + min_num_mm_items, + max_num_mm_items, + bucket_config, + limit_mm_per_prompt, + ) + + mm_content = cast(list[dict[str, Any]], [ + self.generate_mm_item(mm_item_config) + for mm_item_config in mm_item_iterator + ]) + + if enable_multimodal_chat: + # NOTE: For now this option is only provided for completeness + # given that the serve.py benchmark currently does not use it. + mm_chat_prompt: Any = prompt + mm_chat_prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sample_request = SampleRequest( + prompt=mm_chat_prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + multi_modal_data=None, + request_id=request_id_prefix + str(i), + ) + else: + sample_request = SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + multi_modal_data=mm_content, + request_id=request_id_prefix + str(i), + ) + mm_requests.append(sample_request) + return mm_requests + + + +def add_dataset_parser(parser: FlexibleArgumentParser): + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="random", + action=_ValidateDatasetArgs, + choices=[ + "sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf", + "custom", "prefix_repetition", "spec_bench" + ], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( + "--no-stream", + action="store_true", + help="Do not load the dataset in streaming mode.", + ) + parser.add_argument( + "--dataset-path", + type=str, + default=None, + action=_ValidateDatasetArgs, + help="Path to the sharegpt/sonnet dataset. " + "Or the huggingface dataset ID if using HF dataset.", + ) + parser.add_argument( + "--no-oversample", + action="store_true", + help="Do not oversample if the dataset has " \ + "fewer samples than num-prompts.", + ) + + # group for dataset specific arguments + custom_group = parser.add_argument_group("custom dataset options") + custom_group.add_argument( + "--custom-output-len", + type=int, + default=256, + help= + "Number of output tokens per request, used only for custom dataset.", + ) + custom_group.add_argument( + "--custom-skip-chat-template", + action="store_true", + help= + "Skip applying chat template to prompt, used only for custom dataset.", + ) + + spec_bench_group = parser.add_argument_group("spec bench dataset options") + spec_bench_group.add_argument( + "--spec-bench-output-len", + type=int, + default=256, + help= + "Num of output tokens per request, used only for spec bench dataset.", + ) + spec_bench_group.add_argument( + "--spec-bench-category", + type=str, + default=None, + help= + "Category for spec bench dataset. If None, use all categories.", + ) + + sonnet_group = parser.add_argument_group("sonnet dataset options") + sonnet_group.add_argument( + "--sonnet-input-len", + type=int, + default=550, + help= + "Number of input tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-output-len", + type=int, + default=150, + help= + "Number of output tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-prefix-len", + type=int, + default=200, + help= + "Number of prefix tokens per request, used only for sonnet dataset.", + ) + + sharegpt_group = parser.add_argument_group("sharegpt dataset options") + sharegpt_group.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length " + "from the ShareGPT dataset.", + ) + + blazedit_group = parser.add_argument_group("blazedit dataset options") + blazedit_group.add_argument( + "--blazedit-min-distance", + type=float, + default=0.0, + help= + "Minimum distance for blazedit dataset. Min: 0, Max: 1.0", + ) + blazedit_group.add_argument( + "--blazedit-max-distance", + type=float, + default=1.0, + help= + "Maximum distance for blazedit dataset. Min: 0, Max: 1.0", + ) + + random_group = parser.add_argument_group("random dataset options") + random_group.add_argument( + "--random-input-len", + type=int, + default=1024, + help= + "Number of input tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-output-len", + type=int, + default=128, + help= + "Number of output tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range ratio for sampling input/output length, " + "used only for random sampling. Must be in the range [0, 1) to define " + "a symmetric sampling range" + "[length * (1 - range_ratio), length * (1 + range_ratio)].", + ) + random_group.add_argument( + "--random-prefix-len", + type=int, + default=0, + help=("Number of fixed prefix tokens before the random context " + "in a request. " + "The total input length is the sum of `random-prefix-len` and " + "a random " + "context length sampled from [input_len * (1 - range_ratio), " + "input_len * (1 + range_ratio)]."), + ) + random_group.add_argument( + "--random-batch-size", + type=int, + default=1, + help=("Batch size for random sampling. " + "Only used for embeddings benchmark."), + ) + + # random multimodal dataset options + random_mm_group = parser.add_argument_group( + "random multimodal dataset options extended from random dataset") + random_mm_group.add_argument( + "--random-mm-base-items-per-request", + type=int, + default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST, + help=( + "Base number of multimodal items per request for random-mm. " + "Actual per-request count is sampled around this base using " + "--random-mm-num-mm-items-range-ratio." + ), + ) + random_mm_group.add_argument( + "--random-mm-num-mm-items-range-ratio", + type=float, + default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, + help=( + "Range ratio r in [0, 1] for sampling items per request. " + "We sample uniformly from the closed integer range " + "[floor(n*(1-r)), ceil(n*(1+r))] " + "where n is the base items per request. " + "r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped " + "to the sum of per-modality limits from " + "--random-mm-limit-mm-per-prompt. " + "An error is raised if the computed min exceeds the max." + ), + ) + random_mm_group.add_argument( + "--random-mm-limit-mm-per-prompt", + type=json.loads, + default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT, + help=( + "Per-modality hard caps for items attached per request, e.g. " + "'{\"image\": 3, \"video\": 0}'. The sampled per-request item " + "count is clamped to the sum of these limits. When a modality " + "reaches its cap, its buckets are excluded and probabilities are " + "renormalized." + "OBS.: Only image sampling is supported for now." + ), + ) + + def _parse_mm_bucket_config(v: object) -> dict[tuple[int, int, int], float]: + # If already a dict (e.g., programmatic call), normalize keys + def normalize(d: dict) -> dict[tuple[int, int, int], float]: + out: dict[tuple[int, int, int], float] = {} + for k, val in d.items(): + key = k + if isinstance(key, str): + with suppress(Exception): + key = ast.literal_eval(key) + if not (isinstance(key, tuple) and len(key) == 3 + and all(isinstance(x, int) for x in key)): + raise ValueError( + f"Invalid bucket key {k!r}. Expected tuple (H, W, T)." + ) + out[(int(key[0]), int(key[1]), int(key[2]))] = float(val) + return out + + if isinstance(v, dict): + return normalize(v) + if isinstance(v, str): + # Python literal (supports tuple keys) + parsed = ast.literal_eval(v) + if not isinstance(parsed, dict): + raise ValueError("Bucket config must parse to a dict.") + return normalize(parsed) + raise ValueError("Unsupported value for --random-mm-bucket-config.") + + random_mm_group.add_argument( + "--random-mm-bucket-config", + type=_parse_mm_bucket_config, + default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG, + help=( + "The bucket config is a dictionary mapping a multimodal item" + "sampling configuration to a probability." + "Currently allows for 2 modalities: images and videos. " + "An bucket key is a tuple of (height, width, num_frames)" + "The value is the probability of sampling that specific item. " + "Example: " + "--random-mm-bucket-config " + "{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10} " + "First item: images with resolution 256x256 w.p. 0.5" + "Second item: images with resolution 720x1280 w.p. 0.4 " + "Third item: videos with resolution 720x1280 and 16 frames w.p. 0.1" + "OBS.: If the probabilities do not sum to 1, they are normalized." + "OBS bis.: Only image sampling is supported for now." + ), + ) + + hf_group = parser.add_argument_group("hf dataset options") + hf_group.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + hf_group.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-name", + type=str, + default=None, + help=( + "Name of the dataset on HuggingFace " + "(e.g., 'lmarena-ai/VisionArena-Chat'). " + "Specify this if your dataset-path is a local path." + ), + ) + hf_group.add_argument( + "--hf-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output lengths " + "from the sampled HF dataset.", + ) + + prefix_repetition_group = parser.add_argument_group( + "prefix repetition dataset options") + prefix_repetition_group.add_argument( + "--prefix-repetition-prefix-len", + type=int, + default=256, + help="Number of prefix tokens per request, used only for prefix " + "repetition dataset.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-suffix-len", + type=int, + default=256, + help="Number of suffix tokens per request, used only for prefix " + "repetition dataset. Total input length is prefix_len + suffix_len.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-num-prefixes", + type=int, + default=10, + help="Number of prefixes to generate, used only for prefix repetition " + "dataset. Prompts per prefix is num_requests // num_prefixes.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-output-len", + type=int, + default=128, + help="Number of output tokens per request, used only for prefix " + "repetition dataset.", + ) + + +def get_samples(args, tokenizer) -> list[SampleRequest]: + + if not hasattr(args, "request_id_prefix"): + args.request_id_prefix = "" + + if args.dataset_name == "custom": + dataset = CustomDataset(dataset_path=args.dataset_path) + input_requests = dataset.sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.custom_output_len, + skip_chat_template=args.custom_skip_chat_template, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ) + + elif args.dataset_name == "sonnet": + dataset = SonnetDataset(dataset_path=args.dataset_path) + # For the "sonnet" dataset, formatting depends on the backend. + if args.backend == "openai-chat": + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=False, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ) + else: + assert tokenizer.chat_template or tokenizer.default_chat_template, ( + "Tokenizer/model must have chat template for sonnet dataset.") + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=True, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ) + + elif args.dataset_name == "hf": + # all following datasets are implemented from the + # HuggingFaceDataset base class + hf_kwargs = {} + if ( + args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in VisionArenaDataset.SUPPORTED_DATASET_PATHS + ): + dataset_class = VisionArenaDataset + args.hf_split = "train" + args.hf_subset = None + elif ( + args.dataset_path in MMVUDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in MMVUDataset.SUPPORTED_DATASET_PATHS + ): + dataset_class = MMVUDataset + args.hf_split = "validation" + args.hf_subset = None + elif ( + args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in InstructCoderDataset.SUPPORTED_DATASET_PATHS + ): + dataset_class = InstructCoderDataset + args.hf_split = "train" + elif ( + args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in MTBenchDataset.SUPPORTED_DATASET_PATHS + ): + dataset_class = MTBenchDataset + args.hf_split = "train" + elif ( + args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in ConversationDataset.SUPPORTED_DATASET_PATHS + ): + dataset_class = ConversationDataset + elif ( + args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS + or args.hf_name in AIMODataset.SUPPORTED_DATASET_PATHS + ): + dataset_class = AIMODataset + args.hf_split = "train" + elif ( + args.dataset_path + in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS # noqa: E501 + or args.hf_name in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS + ): + dataset_class = NextEditPredictionDataset + args.hf_split = "train" + elif ( + args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in ASRDataset.SUPPORTED_DATASET_PATHS + ): + dataset_class = ASRDataset + args.hf_split = "train" + elif args.dataset_path in BlazeditDataset.SUPPORTED_DATASET_PATHS: + dataset_class = BlazeditDataset + args.hf_split = "train" + hf_kwargs = { + "min_distance": args.blazedit_min_distance, + "max_distance": args.blazedit_max_distance, + } + elif ( + args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in MLPerfDataset.SUPPORTED_DATASET_PATHS + ): + dataset_class = MLPerfDataset + args.hf_split = "train" + else: + supported_datasets = set([ + dataset_name for cls in HuggingFaceDataset.__subclasses__() + for dataset_name in cls.SUPPORTED_DATASET_PATHS + ]) + raise ValueError( + f"Unsupported dataset path: {args.dataset_path}. " + "Huggingface dataset only supports dataset_path" + f" from one of following: {supported_datasets}. " + "Please consider contributing if you would " + "like to add support for additional dataset formats.") + + if dataset_class.IS_MULTIMODAL and args.backend not in [ + "openai-chat", + "openai-audio", + ]: + # multi-modal benchmark is only available on OpenAI Chat + # endpoint-type. + raise ValueError( + "Multi-modal content is only supported on 'openai-chat' and " + "'openai-audio' backends.") + input_requests = dataset_class( + dataset_path=args.dataset_path, + dataset_subset=args.hf_subset, + dataset_split=args.hf_split, + random_seed=args.seed, + no_stream=args.no_stream, + hf_name=args.hf_name, + ).sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.hf_output_len, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + **hf_kwargs + ) + + else: + # For datasets that follow a similar structure, use a mapping. + dataset_mapping = { + "spec_bench": + lambda: SpecBench(dataset_path=args.dataset_path, + category=args.spec_bench_category).sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.spec_bench_output_len, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ), + "sharegpt": lambda: ShareGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ), + "burstgpt": lambda: BurstGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ), + "random": lambda: RandomDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + input_len=args.random_input_len, + output_len=args.random_output_len, + range_ratio=args.random_range_ratio, + request_id_prefix=args.request_id_prefix, + batchsize=args.random_batch_size, + no_oversample=args.no_oversample, + ), + "random-mm": + lambda: RandomMultiModalDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + range_ratio=args.random_range_ratio, + input_len=args.random_input_len, + output_len=args.random_output_len, + base_items_per_request=args.random_mm_base_items_per_request, + limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt, + num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio, + bucket_config=args.random_mm_bucket_config, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ), + "prefix_repetition": + lambda: PrefixRepetitionRandomDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.prefix_repetition_prefix_len, + suffix_len=args.prefix_repetition_suffix_len, + num_prefixes=args.prefix_repetition_num_prefixes, + output_len=args.prefix_repetition_output_len, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ), + } + + try: + # Enforce endpoint compatibility for multimodal datasets. + if args.dataset_name == "random-mm" and args.backend not in [ + "openai-chat"]: + raise ValueError( + "Multi-modal content (images) is only supported on " + "'openai-chat' backend." + ) + input_requests = dataset_mapping[args.dataset_name]() + except KeyError as err: + raise ValueError(f"Unknown dataset: {args.dataset_name}") from err + + return input_requests \ No newline at end of file diff --git a/vllm_omni/benchmarks/lib/__init__.py b/vllm_omni/benchmarks/lib/__init__.py new file mode 100644 index 000000000..005e87af6 --- /dev/null +++ b/vllm_omni/benchmarks/lib/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Benchmark library utilities.""" diff --git a/vllm_omni/benchmarks/lib/endpoint_request_func.py b/vllm_omni/benchmarks/lib/endpoint_request_func.py new file mode 100644 index 000000000..d1b1d5053 --- /dev/null +++ b/vllm_omni/benchmarks/lib/endpoint_request_func.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""The request function for API endpoints.""" + +import io +import json +import os +import sys +import time +import traceback +from collections.abc import Awaitable +from dataclasses import dataclass, field +from typing import Optional, Protocol, Union +import aiohttp +from tqdm.asyncio import tqdm +from vllm.benchmarks.lib.endpoint_request_func import (async_request_openai_completions,async_request_openai_audio, + async_request_openai_embeddings, RequestFunc, + RequestFuncInput, + RequestFuncOutput,StreamedResponseHandler) + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +@dataclass +class MixRequestFuncOutput(RequestFuncOutput): + output_audio_num: int = None + prompt_tokens: int = None + +async def async_request_openai_chat_completions( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith(("chat/completions", "profile")), ( + "OpenAI Chat Completions API URL must end with 'chat/completions'.") + + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + mm_content = request_func_input.multi_modal_content + if isinstance(mm_content, list): + content.extend(mm_content) + elif isinstance(mm_content, dict): + content.append(mm_content) + else: + raise TypeError( + "multi_modal_content must be a dict or list[dict] " + "for openai-chat" + ) + payload = { + "model": + request_func_input.model_name + if request_func_input.model_name else request_func_input.model, + "messages": [ + { + "role": "user", + "content": content + }, + ], + "temperature": + 0.0, + "max_completion_tokens": + request_func_input.output_len, + "stream": + False + } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + if request_func_input.extra_headers: + headers |= request_func_input.extra_headers + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id + + output = MixRequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + output.ttft = 0.0 + st = time.perf_counter() + output.start_time = st + output.output_audio_num = 0 + try: + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: + data = await response.json() + choices = data.get("choices") + for choice in choices: + content = choice["message"].get("content") + output.generated_text += content or "" + if choice["message"].get("audio"): + output.output_audio_num += 1 + usage = data.get("usage") + output.output_tokens = usage.get("completion_tokens") + output.prompt_tokens = usage.get("prompt_tokens") + output.success = True + output.latency = time.perf_counter() - st + output.ttft = output.latency + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + +# TODO: Add more request functions for different API protocols. +ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = { + "vllm": async_request_openai_completions, + "openai": async_request_openai_completions, + "openai-chat": async_request_openai_chat_completions, + "openai-audio": async_request_openai_audio, + "openai-embeddings": async_request_openai_embeddings, +} + +OPENAI_COMPATIBLE_BACKENDS = [ + k for k, v in ASYNC_REQUEST_FUNCS.items() + if v in (async_request_openai_completions, + async_request_openai_chat_completions) +] diff --git a/vllm_omni/benchmarks/serve.py b/vllm_omni/benchmarks/serve.py new file mode 100644 index 000000000..2c29e2c70 --- /dev/null +++ b/vllm_omni/benchmarks/serve.py @@ -0,0 +1,1170 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +r"""Benchmark online serving throughput. + +On the server side, run one of the following commands +to launch the vLLM OpenAI API server: + vllm-omni serve + +On the client side, run: + vllm-omni bench serve \ + --backend \ + --label \ + --model \ + --dataset-name \ + --request-rate \ + --num-prompts +""" +import argparse +import asyncio +import gc +import importlib.util +import json +import os +import random +import shutil +import time +import warnings +from collections.abc import AsyncGenerator, Iterable +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from typing import Any, Literal, Optional + +import aiohttp +import numpy as np +from tqdm.asyncio import tqdm +from transformers import PreTrainedTokenizerBase + +from vllm_omni.benchmarks.datasets import get_samples,add_dataset_parser + +from vllm_omni.benchmarks.lib.endpoint_request_func import ( + ASYNC_REQUEST_FUNCS, OPENAI_COMPATIBLE_BACKENDS) +from vllm_omni.benchmarks.lib.endpoint_request_func import MixRequestFuncOutput + +from vllm.benchmarks.datasets import SampleRequest +from vllm.benchmarks.serve import (BenchmarkMetrics,EmbedBenchmarkMetrics,DeprecatedEndpointTypeAction, + TaskType,get_request,check_goodput_args,save_to_pytorch_benchmark_format) +from vllm.benchmarks.lib.endpoint_request_func import RequestFuncInput,RequestFuncOutput + +from vllm.benchmarks.lib.ready_checker import wait_for_endpoint +from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format, + write_to_json) +from vllm.transformers_utils.tokenizer import get_tokenizer + +MILLISECONDS_TO_SECONDS_CONVERSION = 1000 + +TERM_PLOTLIB_AVAILABLE = ((importlib.util.find_spec("termplotlib") is not None) + and (shutil.which("gnuplot") is not None)) + + + +@dataclass +class MixBenchmarkMetrics(BenchmarkMetrics): + audio_throughput: float + total_text_input: int + + + +def _get_current_request_rate( + ramp_up_strategy: Optional[Literal["linear", "exponential"]], + ramp_up_start_rps: Optional[int], + ramp_up_end_rps: Optional[int], + request_index: int, + total_requests: int, + request_rate: float, +) -> float: + if (ramp_up_strategy and ramp_up_start_rps is not None + and ramp_up_end_rps is not None): + progress = request_index / max(total_requests - 1, 1) + if ramp_up_strategy == "linear": + increase = (ramp_up_end_rps - ramp_up_start_rps) * progress + return ramp_up_start_rps + increase + elif ramp_up_strategy == "exponential": + ratio = ramp_up_end_rps / ramp_up_start_rps + return ramp_up_start_rps * (ratio**progress) + else: + raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}") + return request_rate + + +def calculate_metrics_for_embeddings( + outputs: list[MixRequestFuncOutput], dur_s: float, + selected_percentiles: list[float]) -> EmbedBenchmarkMetrics: + """Calculate the metrics for the embedding requests. + + Args: + outputs: The outputs of the requests. + dur_s: The duration of the benchmark. + selected_percentiles: The percentiles to select. + + Returns: + The calculated benchmark metrics. + """ + total_input = 0 + completed = 0 + e2els: list[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + e2els.append(outputs[i].latency) + completed += 1 + total_input += outputs[i].prompt_len + + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) + metrics = EmbedBenchmarkMetrics( + completed=completed, + total_input=total_input, + request_throughput=completed / dur_s, + total_token_throughput=total_input / dur_s, + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles], + ) + return metrics + + +def calculate_metrics( + input_requests: list[SampleRequest], + outputs: list[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + selected_percentiles: list[float], + goodput_config_dict: dict[str, float], +) -> tuple[MixBenchmarkMetrics, list[int]]: + """Calculate the metrics for the benchmark. + + Args: + input_requests: The input requests. + outputs: The outputs of the requests. + dur_s: The duration of the benchmark. + tokenizer: The tokenizer to use. + selected_percentiles: The percentiles to select. + goodput_config_dict: The goodput configuration. + + Returns: + A tuple of the benchmark metrics and the actual output lengths. + """ + actual_output_lens: list[int] = [] + total_input = 0 + total_text_input = 0 + completed = 0 + audio_completed = 0 + good_completed = 0 + itls: list[float] = [] + tpots: list[float] = [] + all_tpots: list[float] = [] + ttfts: list[float] = [] + e2els: list[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + output_len = outputs[i].output_tokens + + if not output_len: + # We use the tokenizer to count the number of output tokens + # for some serving backends instead of looking at + # len(outputs[i].itl) since multiple output tokens may be + # bundled together + # Note : this may inflate the output token count slightly + output_len = len( + tokenizer(outputs[i].generated_text, + add_special_tokens=False).input_ids) + actual_output_lens.append(output_len) + total_text_input += input_requests[i].prompt_len + total_input += outputs[i].prompt_tokens + tpot = 0 + if output_len > 1: + latency_minus_ttft = outputs[i].latency - outputs[i].ttft + tpot = latency_minus_ttft / (output_len - 1) + tpots.append(tpot) + # Note: if output_len <= 1, we regard tpot as 0 for goodput + all_tpots.append(tpot) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + e2els.append(outputs[i].latency) + completed += 1 + audio_completed += outputs[i].output_audio_num + else: + actual_output_lens.append(0) + + if goodput_config_dict: + valid_metrics = [] + slo_values = [] + + if "ttft" in goodput_config_dict: + valid_metrics.append(ttfts) + slo_values.append(goodput_config_dict["ttft"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "tpot" in goodput_config_dict: + valid_metrics.append(all_tpots) + slo_values.append(goodput_config_dict["tpot"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "e2el" in goodput_config_dict: + valid_metrics.append(e2els) + slo_values.append(goodput_config_dict["e2el"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + + for req_metric in zip(*valid_metrics): + is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) + if is_good_req: + good_completed += 1 + + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) + + # Calculate max output tokens per second metric + max_output_tokens_per_s = 0.0 + max_concurrent_requests = 0 + + # Find the time range across all successful requests + successful_outputs = [output for output in outputs if output.success] + if successful_outputs: + min_start_time = min(output.start_time + for output in successful_outputs) + max_end_time = max(output.start_time + output.latency + for output in successful_outputs) + + # Create second buckets (ceiling to ensure we capture all time) + duration_seconds = int(np.ceil(max_end_time - min_start_time)) + 1 + tokens_per_second = np.zeros(duration_seconds) + concurrent_requests_per_second = np.zeros(duration_seconds) + + for i, output in enumerate(successful_outputs): + # Calculate token generation timestamp using + # start_time, ttft, and itl + token_times = [output.start_time + output.ttft] + current_time = token_times[0] + for itl_value in output.itl: + current_time += itl_value + token_times.append(current_time) + + # Add tokens to second buckets + for token_time in token_times: + second_bucket = int(token_time - min_start_time) + if 0 <= second_bucket < duration_seconds: + tokens_per_second[second_bucket] += 1 + + # Track concurrent requests for each second this request was active + request_start_second = int(output.start_time - min_start_time) + request_end_second = int((output.start_time + output.latency) - + min_start_time) + + for second in range(request_start_second, request_end_second + 1): + concurrent_requests_per_second[second] += 1 + + # Find the maximum tokens per second and corresponding + # concurrent requests + if len(tokens_per_second) > 0: + max_output_tokens_per_s = float(np.max(tokens_per_second)) + max_concurrent_requests = int( + np.max(concurrent_requests_per_second)) + + if TERM_PLOTLIB_AVAILABLE: + import termplotlib as tpl + fig = tpl.figure() + fig.plot(np.arange(len(tokens_per_second)), + tokens_per_second, + title="Output tokens per second") + fig.plot(np.arange(len(concurrent_requests_per_second)), + concurrent_requests_per_second, + title="Concurrent requests per second") + fig.show() + else: + print("tip: install termplotlib and gnuplot to plot the metrics") + + metrics = MixBenchmarkMetrics( + completed=completed, + total_input=total_input, + total_text_input=total_text_input, + total_output=sum(actual_output_lens), + request_throughput=completed / dur_s, + request_goodput=good_completed / dur_s, + output_throughput=sum(actual_output_lens) / dur_s, + audio_throughput=audio_completed / dur_s, + total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) * + 1000, # ttfts is empty if streaming is not supported by the endpoint + std_ttft_ms=np.std(ttfts or 0) * 1000, + median_ttft_ms=np.median(ttfts or 0) * 1000, + percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) + for p in selected_percentiles], + mean_tpot_ms=np.mean(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) + for p in selected_percentiles], + mean_itl_ms=np.mean(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) + for p in selected_percentiles], + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles], + max_output_tokens_per_s=max_output_tokens_per_s, + max_concurrent_requests=max_concurrent_requests, + ) + + return metrics, actual_output_lens + + +async def benchmark( + endpoint_type: str, + api_url: str, + base_url: str, + model_id: str, + model_name: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: list[SampleRequest], + logprobs: Optional[int], + request_rate: float, + burstiness: float, + disable_tqdm: bool, + profile: bool, + selected_percentile_metrics: list[str], + selected_percentiles: list[float], + ignore_eos: bool, + goodput_config_dict: dict[str, float], + max_concurrency: Optional[int], + lora_modules: Optional[Iterable[str]], + extra_headers: Optional[dict], + extra_body: Optional[dict], + ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, + ramp_up_start_rps: Optional[int] = None, + ramp_up_end_rps: Optional[int] = None, + ready_check_timeout_sec: int = 600, +): + task_type = (TaskType.EMBEDDING if api_url.endswith("/v1/embeddings") else + TaskType.GENERATION) + if endpoint_type in ASYNC_REQUEST_FUNCS: + if task_type == TaskType.EMBEDDING: + request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"] + else: + request_func = ASYNC_REQUEST_FUNCS[endpoint_type] + else: + raise ValueError(f"Unknown backend: {endpoint_type}") + + # Reuses connections across requests to reduce TLS handshake overhead. + connector = aiohttp.TCPConnector( + limit=max_concurrency or 0, + limit_per_host=max_concurrency or 0, + ttl_dns_cache=300, + use_dns_cache=True, + keepalive_timeout=60, + enable_cleanup_closed=True, + force_close=False, + ssl=("https://" in api_url), + ) + + session = aiohttp.ClientSession( + connector=connector, + trust_env=True, + timeout=aiohttp.ClientTimeout(total=6 * 60 * 60), + ) + + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len, test_mm_content = ( + input_requests[0].prompt, + input_requests[0].prompt_len, + input_requests[0].expected_output_len, + input_requests[0].multi_modal_data, + ) + + assert (test_mm_content is None or isinstance(test_mm_content, dict) + or (isinstance(test_mm_content, list) + and all(isinstance(item, dict) for item in test_mm_content)) + ), "multi_modal_data must be a dict or list[dict]" + test_input = RequestFuncInput( + model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, + extra_headers=extra_headers, + extra_body=extra_body, + ) + + if ready_check_timeout_sec > 0: + test_output = await wait_for_endpoint( + request_func, + test_input, + session, + timeout_seconds=ready_check_timeout_sec, + ) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark " + "arguments are correctly specified. " + f"Error: {test_output.error}") + else: + print("Initial test run completed. Starting main benchmark run...") + else: + print("Skipping endpoint ready check.") + + if lora_modules: + # For each input request, choose a LoRA module at random. + lora_modules = iter( + [random.choice(lora_modules) for _ in range(len(input_requests))]) + + if profile: + print("Starting profiler...") + profile_input = RequestFuncInput(model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, + extra_headers=extra_headers, + extra_body=extra_body) + profile_output = await request_func(request_func_input=profile_input, + session=session) + if profile_output.success: + print("Profiler started") + + distribution = ("Poisson process" + if burstiness == 1.0 else "Gamma distribution") + + if ramp_up_strategy is not None: + print(f"Traffic ramp-up strategy: {ramp_up_strategy}.") + print(f"Will increase RPS from {ramp_up_start_rps} to " + f"{ramp_up_end_rps} RPS over the duration of the benchmark.") + else: + print(f"Traffic request rate: {request_rate}") + + print(f"Burstiness factor: {burstiness} ({distribution})") + print(f"Maximum request concurrency: {max_concurrency}") + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + # This can be used once the minimum Python version is 3.10 or higher, + # and it will simplify the code in limited_request_func. + # semaphore = (asyncio.Semaphore(max_concurrency) + # if max_concurrency else contextlib.nullcontext()) + semaphore = (asyncio.Semaphore(max_concurrency) + if max_concurrency else None) + + async def limited_request_func(request_func_input, session, pbar): + if semaphore is None: + return await request_func(request_func_input=request_func_input, + session=session, + pbar=pbar) + async with semaphore: + return await request_func(request_func_input=request_func_input, + session=session, + pbar=pbar) + + benchmark_start_time = time.perf_counter() + tasks: list[asyncio.Task] = [] + + rps_change_events = [] + last_int_rps = -1 + if ramp_up_strategy is not None and ramp_up_start_rps is not None: + last_int_rps = ramp_up_start_rps + rps_change_events.append({ + "rps": last_int_rps, + "timestamp": datetime.now().isoformat(), + }) + + async for request, current_request_rate in get_request( + input_requests, request_rate, burstiness, ramp_up_strategy, + ramp_up_start_rps, ramp_up_end_rps): + if ramp_up_strategy is not None: + current_int_rps = int(current_request_rate) + if current_int_rps > last_int_rps: + timestamp = datetime.now().isoformat() + for rps_val in range(last_int_rps + 1, current_int_rps + 1): + rps_change_events.append({ + "rps": rps_val, + "timestamp": timestamp + }) + last_int_rps = current_int_rps + prompt, prompt_len, output_len, mm_content, request_id = ( + request.prompt, + request.prompt_len, + request.expected_output_len, + request.multi_modal_data, + request.request_id, + ) + req_model_id, req_model_name = model_id, model_name + if lora_modules: + req_lora_module = next(lora_modules) + req_model_id, req_model_name = req_lora_module, req_lora_module + + request_func_input = RequestFuncInput( + model=req_model_id, + model_name=req_model_name, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + logprobs=logprobs, + multi_modal_content=mm_content, + ignore_eos=ignore_eos, + extra_headers=extra_headers, + extra_body=extra_body, + request_id=request_id, + ) + tasks.append( + asyncio.create_task( + limited_request_func(request_func_input=request_func_input, + session=session, + pbar=pbar))) + outputs: list[MixRequestFuncOutput] = await asyncio.gather(*tasks) + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + if task_type == TaskType.GENERATION: + metrics, actual_output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, + ) + else: + metrics = calculate_metrics_for_embeddings( + outputs=outputs, + dur_s=benchmark_duration, + selected_percentiles=selected_percentiles, + ) + actual_output_lens = 0 + + print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + if max_concurrency is not None: + print("{:<40} {:<10}".format("Maximum request concurrency:", + max_concurrency)) + if request_rate != float('inf'): + print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", + request_rate)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", + benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total text input tokens:", metrics.total_text_input)) + if isinstance(metrics, MixBenchmarkMetrics): + print("{:<40} {:<10}".format("Total generated tokens:", + metrics.total_output)) + print("{:<40} {:<10.2f}".format("Request throughput (req/s):", + metrics.request_throughput)) + print("{:<40} {:<10.2f}".format("Audio throughput (num/s):", + metrics.audio_throughput)) + if goodput_config_dict: + print("{:<40} {:<10.2f}".format("Request goodput (req/s):", + metrics.request_goodput)) + if isinstance(metrics, MixBenchmarkMetrics): + print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", + metrics.output_throughput)) + print("{:<40} {:<10.2f}".format( + "Peak output token throughput (tok/s):", + metrics.max_output_tokens_per_s)) + print("{:<40} {:<10.2f}".format("Peak concurrent requests:", + metrics.max_concurrent_requests)) + print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", + metrics.total_token_throughput)) + + if isinstance(metrics, MixBenchmarkMetrics): + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_text_input_tokens": metrics.total_text_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "request_goodput": + metrics.request_goodput if goodput_config_dict else None, + "output_throughput": metrics.output_throughput, + "audio_throughput": metrics.audio_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + "max_output_tokens_per_s": metrics.max_output_tokens_per_s, + "max_concurrent_requests": metrics.max_concurrent_requests, + } + else: + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "request_throughput": metrics.request_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "errors": [output.error for output in outputs], + } + + if rps_change_events: + result["rps_change_events"] = rps_change_events + + def process_one_metric( + # E.g., "ttft" + metric_attribute_name: str, + # E.g., "TTFT" + metric_name: str, + # E.g., "Time to First Token" + metric_header: str, + ): + # This function prints and adds statistics of the specified + # metric. + if metric_attribute_name not in selected_percentile_metrics: + return + print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) + print("{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"))) + print("{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"))) + result[f"mean_{metric_attribute_name}_ms"] = getattr( + metrics, f"mean_{metric_attribute_name}_ms") + result[f"median_{metric_attribute_name}_ms"] = getattr( + metrics, f"median_{metric_attribute_name}_ms") + result[f"std_{metric_attribute_name}_ms"] = getattr( + metrics, f"std_{metric_attribute_name}_ms") + for p, value in getattr(metrics, + f"percentiles_{metric_attribute_name}_ms"): + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", + value)) + result[f"p{p_word}_{metric_attribute_name}_ms"] = value + + if task_type == TaskType.GENERATION: + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric("tpot", "TPOT", + "Time per Output Token (excl. 1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") + process_one_metric("e2el", "E2EL", "End-to-end Latency") + + print("=" * 50) + + if profile: + print("Stopping profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=base_url + "/stop_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + ) + profile_output = await request_func(request_func_input=profile_input, + session=session) + if profile_output.success: + print("Profiler stopped") + + await session.close() + return result + + +def add_cli_args(parser: argparse.ArgumentParser): + add_dataset_parser(parser) + parser.add_argument( + "--label", + type=str, + default=None, + help="The label (prefix) of the benchmark results. If not specified, " + "the value of '--backend' will be used as the label.", + ) + parser.add_argument( + "--backend", + type=str, + default="openai", + choices=list(ASYNC_REQUEST_FUNCS.keys()), + help="The type of backend or endpoint to use for the benchmark." + ) + parser.add_argument( + "--endpoint-type", + type=str, + default=None, + choices=list(ASYNC_REQUEST_FUNCS.keys()), + action=DeprecatedEndpointTypeAction, + help="'--endpoint-type' is deprecated and will be removed in v0.11.0. " + "Please use '--backend' instead.", + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + # Use 127.0.0.1 here instead of localhost to force the use of ipv4 + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--endpoint", + type=str, + default="/v1/completions", + help="API endpoint.", + ) + parser.add_argument( + "--header", + metavar="KEY=VALUE", + nargs="*", + help="Key-value pairs (e.g, --header x-additional-info=0.3.3) " + "for headers to be passed with each request. These headers override " \ + "per backend constants and values set via environment variable, and " \ + "will be overriden by other arguments (such as request ids)." + ) + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. This can be used " + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.", + ) + + parser.add_argument( + "--model", + type=str, + required=True, + help="Name of the model.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help= + "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + ) + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument( + "--logprobs", + type=int, + default=None, + help=("Number of logprobs-per-token to compute & return as part of " + "the request. If unspecified, then either (1) if beam search " + "is disabled, no logprobs are computed & a single dummy " + "logprob is returned for each token; or (2) if beam search " + "is enabled 1 logprob per token is computed"), + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, " + "then all the requests are sent at time 0. " + "Otherwise, we use Poisson process or gamma distribution " + "to synthesize the request arrival times.", + ) + parser.add_argument( + "--burstiness", + type=float, + default=1.0, + help="Burstiness factor of the request generation. " + "Only take effect when request_rate is not inf. " + "Default value is 1, which follows Poisson process. " + "Otherwise, the request intervals follow a gamma distribution. " + "A lower burstiness value (0 < burstiness < 1) results in more " + "bursty requests. A higher burstiness value (burstiness > 1) " + "results in a more uniform arrival of requests.", + ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Trust remote code from huggingface", + ) + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. The endpoint must be launched with " + "VLLM_TORCH_PROFILER_DIR to enable profiler.", + ) + parser.add_argument( + "--save-result", + action="store_true", + help="Specify to save benchmark results to a json file", + ) + parser.add_argument( + "--save-detailed", + action="store_true", + help="When saving the results, whether to include per request " + "information such as response, error, ttfs, tpots, etc.", + ) + parser.add_argument( + "--append-result", + action="store_true", + help="Append the benchmark result to the existing json file.", + ) + parser.add_argument( + "--metadata", + metavar="KEY=VALUE", + nargs="*", + help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) " + "for metadata of this run to be saved in the result JSON file " + "for record keeping purposes.", + ) + parser.add_argument( + "--result-dir", + type=str, + default=None, + help="Specify directory to save benchmark json results." + "If not specified, results are saved in the current directory.", + ) + parser.add_argument( + "--result-filename", + type=str, + default=None, + help="Specify the filename to save benchmark json results." + "If not specified, results will be saved in " + "{label}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" # noqa + " format.", + ) + parser.add_argument( + "--ignore-eos", + action="store_true", + help="Set ignore_eos flag when sending the benchmark request." + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + parser.add_argument( + "--percentile-metrics", + type=str, + default="ttft,tpot,itl,e2el", + help="Comma-separated list of selected metrics to report percentils. " + "This argument specifies the metrics to report percentiles. " + "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". ") + parser.add_argument( + "--metric-percentiles", + type=str, + default="99", + help="Comma-separated list of percentiles for selected metrics. " + "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " + "Default value is \"99\"." + "Use \"--percentile-metrics\" to select metrics.", + ) + parser.add_argument( + "--goodput", + nargs="+", + required=False, + help="Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is in " + "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + "separated by spaces. Allowed request level metric names are " + "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " + "and the blog: https://hao-ai-lab.github.io/blogs/distserve", + ) + parser.add_argument( + "--request-id-prefix", + type=str, + required=False, + default="benchmark-serving", + help="Specify the prefix of request id.", + ) + + sampling_group = parser.add_argument_group("sampling parameters") + sampling_group.add_argument( + "--top-p", + type=float, + default=None, + help="Top-p sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--top-k", + type=int, + default=None, + help="Top-k sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--min-p", + type=float, + default=None, + help="Min-p sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--temperature", + type=float, + default=None, + help="Temperature sampling parameter. Only has effect on " + "openai-compatible backends. If not specified, default to greedy " + "decoding (i.e. temperature==0.0).", + ) + + parser.add_argument( + '--tokenizer-mode', + type=str, + default="auto", + choices=['auto', 'slow', 'mistral', 'custom'], + help='The tokenizer mode.\n\n* "auto" will use the ' + 'fast tokenizer if available.\n* "slow" will ' + 'always use the slow tokenizer. \n* ' + '"mistral" will always use the `mistral_common` tokenizer. \n*' + '"custom" will use --tokenizer to select the preregistered tokenizer.') + + parser.add_argument("--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. ") + + parser.add_argument("--lora-modules", + nargs='+', + default=None, + help="A subset of LoRA module names passed in when " + "launching the server. For each request, the " + "script chooses a LoRA module at random.") + + parser.add_argument( + "--ramp-up-strategy", + type=str, + default=None, + choices=["linear", "exponential"], + help="The ramp-up strategy. This would be used to " + "ramp up the request rate from initial RPS to final " + "RPS rate (specified by --ramp-up-start-rps and " + "--ramp-up-end-rps.) over the duration of the benchmark.") + parser.add_argument( + "--ramp-up-start-rps", + type=int, + default=None, + help="The starting request rate for ramp-up (RPS). " + "Needs to be specified when --ramp-up-strategy is used.", + ) + parser.add_argument( + "--ramp-up-end-rps", + type=int, + default=None, + help="The ending request rate for ramp-up (RPS). " + "Needs to be specified when --ramp-up-strategy is used.", + ) + parser.add_argument( + "--ready-check-timeout-sec", + type=int, + default=600, + help="Maximum time to wait for the endpoint to become ready " + "in seconds (default: 600 seconds / 10 minutes). If set to 0, " + "the ready check will be skipped." + ) + + +def main(args: argparse.Namespace) -> dict[str, Any]: + return asyncio.run(main_async(args)) + + +async def main_async(args: argparse.Namespace) -> dict[str, Any]: + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + + # Validate ramp-up arguments + if args.ramp_up_strategy is not None: + if args.request_rate != float("inf"): + raise ValueError( + "When using ramp-up, do not specify --request-rate. " + "The request rate will be controlled by ramp-up parameters. " + "Please remove the --request-rate argument.") + if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None: + raise ValueError( + "When using --ramp-up-strategy, both --ramp-up-start-rps and " + "--ramp-up-end-rps must be specified") + if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0: + raise ValueError("Ramp-up start and end RPS must be non-negative") + if args.ramp_up_start_rps > args.ramp_up_end_rps: + raise ValueError("Ramp-up start RPS must be less than end RPS") + if (args.ramp_up_strategy == "exponential" + and args.ramp_up_start_rps == 0): + raise ValueError( + "For exponential ramp-up, the start RPS cannot be 0.") + + label = args.label + model_id = args.model + model_name = args.served_model_name + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + tokenizer_mode = args.tokenizer_mode + + if args.base_url is not None: + api_url = f"{args.base_url}{args.endpoint}" + base_url = f"{args.base_url}" + else: + api_url = f"http://{args.host}:{args.port}{args.endpoint}" + base_url = f"http://{args.host}:{args.port}" + + # Headers + headers = None + if args.header: + headers = {} + for item in args.header: + if "=" in item: + kvstring = item.split("=", 1) + headers[kvstring[0].strip()] = kvstring[1].strip() + else: + raise ValueError( + "Invalid header format. Please use KEY=VALUE format.") + + tokenizer = get_tokenizer(tokenizer_id, + tokenizer_mode=tokenizer_mode, + trust_remote_code=args.trust_remote_code) + + if args.dataset_name is None: + raise ValueError( + "Please specify '--dataset-name' and the corresponding " + "'--dataset-path' if required.") + + # Load the dataset. + input_requests = get_samples(args, tokenizer) + goodput_config_dict = check_goodput_args(args) + + # Collect the sampling parameters. + sampling_params = { + k: v + for k, v in { + "top_p": args.top_p, + "top_k": args.top_k, + "min_p": args.min_p, + "temperature": args.temperature, + }.items() if v is not None + } + + # Sampling parameters are only supported by openai-compatible backend. + if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: + raise ValueError("Sampling parameters are only supported by " + "openai-compatible backends.") + + if "temperature" not in sampling_params: + sampling_params["temperature"] = 0.0 # Default to greedy decoding. + + # Avoid GC processing "static" data - reduce pause times. + gc.collect() + gc.freeze() + + benchmark_result = await benchmark( + endpoint_type=args.backend, + api_url=api_url, + base_url=base_url, + model_id=model_id, + model_name=model_name, + tokenizer=tokenizer, + input_requests=input_requests, + logprobs=args.logprobs, + request_rate=args.request_rate, + burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], + ignore_eos=args.ignore_eos, + goodput_config_dict=goodput_config_dict, + max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, + extra_headers=headers, + extra_body=sampling_params, + ramp_up_strategy=args.ramp_up_strategy, + ramp_up_start_rps=args.ramp_up_start_rps, + ramp_up_end_rps=args.ramp_up_end_rps, + ready_check_timeout_sec=args.ready_check_timeout_sec, + ) + + # Save config and results to json + result_json: dict[str, Any] = {} + + # Setup + current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") + result_json["date"] = current_dt + result_json["endpoint_type"] = args.backend # for backward compatibility + result_json["backend"] = args.backend + result_json["label"] = label + result_json["model_id"] = model_id + result_json["tokenizer_id"] = tokenizer_id + result_json["num_prompts"] = args.num_prompts + + # Metadata + if args.metadata: + for item in args.metadata: + if "=" in item: + kvstring = item.split("=", 1) + result_json[kvstring[0].strip()] = kvstring[1].strip() + else: + raise ValueError( + "Invalid metadata format. Please use KEY=VALUE format.") + + # Traffic + result_json["request_rate"] = (args.request_rate if args.request_rate + < float("inf") else "inf") + result_json["burstiness"] = args.burstiness + result_json["max_concurrency"] = args.max_concurrency + + if args.ramp_up_strategy is not None: + result_json["ramp_up_strategy"] = args.ramp_up_strategy + result_json["ramp_up_start_rps"] = args.ramp_up_start_rps + result_json["ramp_up_end_rps"] = args.ramp_up_end_rps + + # Merge with benchmark result + result_json = {**result_json, **benchmark_result} + + if not args.save_detailed: + # Remove fields with too many data points + for field in [ + "input_lens", + "output_lens", + "ttfts", + "itls", + "generated_texts", + "errors", + ]: + if field in result_json: + del result_json[field] + if field in benchmark_result: + del benchmark_result[field] + + # Save to file + if args.save_result or args.append_result: + base_model_id = model_id.split("/")[-1] + max_concurrency_str = (f"-concurrency{args.max_concurrency}" + if args.max_concurrency is not None else "") + label = label or args.backend + if args.ramp_up_strategy is not None: + file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + else: + file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + if args.result_filename: + file_name = args.result_filename + if args.result_dir: + os.makedirs(args.result_dir, exist_ok=True) + file_name = os.path.join(args.result_dir, file_name) + with open(file_name, + mode="a+" if args.append_result else "w", + encoding="utf-8") as outfile: + # Append a newline. + if args.append_result and outfile.tell() != 0: + outfile.write("\n") + json.dump(result_json, outfile) + save_to_pytorch_benchmark_format(args, result_json, file_name) + + return result_json diff --git a/vllm_omni/entrypoints/cli/__init__.py b/vllm_omni/entrypoints/cli/__init__.py index b233a71e6..605b9cc7f 100644 --- a/vllm_omni/entrypoints/cli/__init__.py +++ b/vllm_omni/entrypoints/cli/__init__.py @@ -1,5 +1,6 @@ """CLI helpers for vLLM-Omni entrypoints.""" from .serve import OmniServeCommand +from vllm_omni.entrypoints.cli.benchmark.serve import OmniBenchmarkServingSubcommand -__all__ = ["OmniServeCommand"] +__all__ = ["OmniServeCommand", "OmniBenchmarkServingSubcommand"] diff --git a/vllm_omni/entrypoints/cli/benchmark/__init__.py b/vllm_omni/entrypoints/cli/benchmark/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm_omni/entrypoints/cli/benchmark/base.py b/vllm_omni/entrypoints/cli/benchmark/base.py new file mode 100644 index 000000000..b12e0fe02 --- /dev/null +++ b/vllm_omni/entrypoints/cli/benchmark/base.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse + +from vllm.entrypoints.cli.types import CLISubcommand + + +class OmniBenchmarkSubcommandBase(CLISubcommand): + """ The base class of subcommands for vllm bench. """ + + help: str + + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> None: + """Add the CLI arguments to the parser.""" + raise NotImplementedError + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + """Run the benchmark. + + Args: + args: The arguments to the command. + """ + raise NotImplementedError diff --git a/vllm_omni/entrypoints/cli/benchmark/main.py b/vllm_omni/entrypoints/cli/benchmark/main.py new file mode 100644 index 000000000..3b4574b43 --- /dev/null +++ b/vllm_omni/entrypoints/cli/benchmark/main.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import argparse +import typing + +from vllm_omni.entrypoints.cli.benchmark.base import OmniBenchmarkSubcommandBase +from vllm.entrypoints.cli.types import CLISubcommand +from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG + +if typing.TYPE_CHECKING: + from vllm.utils import FlexibleArgumentParser + + +class OmniBenchmarkSubcommand(CLISubcommand): + """ The `bench` subcommand for the vLLM CLI. """ + + name = "bench" + help = "vLLM bench subcommand." + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + args.dispatch_function(args) + + def validate(self, args: argparse.Namespace) -> None: + pass + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + bench_parser = subparsers.add_parser( + self.name, + description=self.help, + usage=f"vllm {self.name} [options]") + bench_subparsers = bench_parser.add_subparsers(required=True, + dest="bench_type") + + + for cmd_cls in OmniBenchmarkSubcommandBase.__subclasses__(): + cmd_subparser = bench_subparsers.add_parser( + cmd_cls.name, + help=cmd_cls.help, + description=cmd_cls.help, + usage=f"vllm {self.name} {cmd_cls.name} [--omni] [options]", + ) + cmd_subparser.add_argument( + "--omni", + action="store_true", + default=True, + help="Enable benchmark-Omni mode (always enabled for omni commands)", + ) + cmd_subparser.set_defaults(dispatch_function=cmd_cls.cmd) + cmd_cls.add_cli_args(cmd_subparser) + + cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format( + subcmd=f"{self.name} {cmd_cls.name}") + + return bench_parser + + +def cmd_init() -> list[CLISubcommand]: + return [OmniBenchmarkSubcommand()] diff --git a/vllm_omni/entrypoints/cli/benchmark/serve.py b/vllm_omni/entrypoints/cli/benchmark/serve.py new file mode 100644 index 000000000..e69c9f77a --- /dev/null +++ b/vllm_omni/entrypoints/cli/benchmark/serve.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse + +from vllm_omni.benchmarks.serve import add_cli_args, main +from vllm_omni.entrypoints.cli.benchmark.base import OmniBenchmarkSubcommandBase + + +class OmniBenchmarkServingSubcommand(OmniBenchmarkSubcommandBase): + """ The `serve` subcommand for vllm bench. """ + name = "serve" + help = "Benchmark the online serving throughput." + + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> None: + add_cli_args(parser) + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + main(args) diff --git a/vllm_omni/entrypoints/cli/main.py b/vllm_omni/entrypoints/cli/main.py index 6a65d9d6c..2a355be9b 100644 --- a/vllm_omni/entrypoints/cli/main.py +++ b/vllm_omni/entrypoints/cli/main.py @@ -19,9 +19,11 @@ def main(): from vllm.utils.argparse_utils import FlexibleArgumentParser import vllm_omni.entrypoints.cli.serve + import vllm_omni.entrypoints.cli.benchmark.main CMD_MODULES = [ vllm_omni.entrypoints.cli.serve, + vllm_omni.entrypoints.cli.benchmark.main, ] cli_env_setup()