diff --git a/examples/nixl/disaggregated_prefill_server_launcher b/examples/nixl/disaggregated_prefill_server_launcher
new file mode 100755
index 000000000..766440277
--- /dev/null
+++ b/examples/nixl/disaggregated_prefill_server_launcher
@@ -0,0 +1,544 @@
+#!/bin/bash
+
+BASH_DIR=$(dirname "${BASH_SOURCE[0]}")
+
+# Help function
+show_help() {
+    cat << EOF
+Usage: $0 [OPTIONS]
+
+Launch disaggregated vLLM servers for prefill or decode operations.
+
+OPTIONS:
+    -h, --help                         Show this help message
+    -m, --model MODEL                  Model to serve (default: ibm-research/PowerMoE-3b)
+    -r, --role ROLE                    Server role: prefill or decode (default: prefill)
+    -n, --num-instances NUM            Number of local instances (default: 1)
+    -t, --tp-size SIZE                 Tensor parallel size (default: 1)
+    -d, --dp-size SIZE                 Data parallel size (default: 1)
+    --base-port PORT                   Base port for servers (default: 8300)
+    --base-channel-port PORT           Base channel port (default: 4300)
+    --node-size SIZE                   Total number of nodes for this role's DP group (default: 1)
+    --node-rank RANK                   Data parallel node rank within this role's group (default: 0)
+    --node-ip IP                       IP address of this node (default: localhost)
+    --dp-master-ip IP                  Data parallel master IP (default: localhost)
+    --dp-master-port PORT              Data parallel master port (default: 6300)
+    --nixl-buffer-device DEVICE        Buffer device: cpu or hpu (default: cpu)
+    --nixl-backend BACKEND             NIXL backend: UCX or other (default: UCX)
+    --ucx-tls TLS                      UCX transport layers (default: rc,ud,ib)
+    --max-model-len LENGTH             Maximum model length (default: 8192)
+    --max-num-batched-tokens TOKENS    Maximum number of batched tokens (default: 8192)
+    --max-num-seqs SEQS                Maximum number of sequences (default: 256)
+    --max-cudagraph-capture-size SIZE  Maximum CUDA graph capture size (default: 4096)
+    --gpu-memory-utilization RATIO     GPU memory utilization ratio (default: 0.75)
+    --enforce-eager                    Enable eager execution mode
+    --debug                            Enable debug mode with reduced model layers
+    --apc                              Enable vLLM prefix caching
+    --profile                          Enable profiling (HABANA_PROFILE and VLLM_PROFILER_ENABLED)
+    --recipe-cache                     Enable HPU recipe cache
+    --async                            Enable async scheduling for the decode role (adds --async-scheduling)
+    --warmup                           Enable vLLM warmup (do not skip warmup)
+    --no-ep                            Disable expert parallelism
+    --log-dir DIR                      Directory to save server logs (default: current directory)
+                                       Logs will be saved in DIR/xpyd_logs/YYYYMMDD_HHMMSS/
+
+EXAMPLES:
+    # Launch prefill server with default settings
+    $0
+
+    # Launch decode server with 2 instances
+    $0 --role decode --num-instances 2
+
+    # Launch prefill server with custom model and tensor parallelism
+    $0 --model meta-llama/Llama-2-7b-hf --tp-size 4
+
+    # Launch on a specific node in a multi-node setup
+    $0 --node-rank 1 --node-ip 192.168.1.100 --dp-master-ip 192.168.1.50
+
+    # Launch with vLLM prefix caching enabled
+    $0 --apc
+
+    # Launch with profiling enabled
+    $0 --profile
+
+    # Launch decode server with async scheduling enabled
+    $0 --role decode --async
+
+    # Launch with warmup enabled
+    $0 --warmup
+
+    # Launch with custom log directory
+    $0 --log-dir /tmp/vllm-logs
+    # (Logs will be saved in /tmp/vllm-logs/xpyd_logs/YYYYMMDD_HHMMSS/)
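+
+    # Launch the second node of a 2-node decode group, 2 local instances per
+    # node, data-parallel size 4 (illustrative IPs)
+    $0 --role decode --num-instances 2 --node-size 2 --node-rank 1 --node-ip 192.168.1.101 --dp-master-ip 192.168.1.100 --dp-size 4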
+
+EOF
+}
+
+# Default values
+MODEL="ibm-research/PowerMoE-3b"
+SERVER_ROLE="prefill"
+NUM_LOCAL_INSTANCES=1
+BASE_PORT=8300
+BASE_CHANNEL_PORT=4300
+NODE_SIZE=1
+NODE_RANK=0
+NODE_IP="localhost"
+TP_SIZE=1
+DP_SIZE=1
+DP_MASTER_IP="localhost"
+DP_MASTER_PORT=6300
+NIXL_BUFFER_DEVICE="cpu"
+VLLM_NIXL_BACKEND="UCX"
+UCX_TLS="rc,ud,ib"
+MAX_MODEL_LEN=8192
+MAX_NUM_BATCHED_TOKENS=8192
+MAX_NUM_SEQS=256
+MAX_CUDAGRAPH_CAPTURE_SIZE=4096
+GPU_MEMORY_UTILIZATION=0.75
+ENFORCE_EAGER=false
+DEBUG=false
+APC=false
+PROFILE=false
+RECIPE_CACHE=false
+ASYNC=false
+WARMUP=false
+ENABLE_EXPERT_PARALLEL=true
+LOG_DIR="."
+
+# Parse command line arguments
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        -h|--help)
+            show_help
+            exit 0
+            ;;
+        -m|--model)
+            MODEL="$2"
+            shift 2
+            ;;
+        -r|--role)
+            SERVER_ROLE="$2"
+            shift 2
+            ;;
+        -n|--num-instances)
+            NUM_LOCAL_INSTANCES="$2"
+            shift 2
+            ;;
+        -t|--tp-size)
+            TP_SIZE="$2"
+            shift 2
+            ;;
+        -d|--dp-size)
+            DP_SIZE="$2"
+            shift 2
+            ;;
+        --base-port)
+            BASE_PORT="$2"
+            shift 2
+            ;;
+        --base-channel-port)
+            BASE_CHANNEL_PORT="$2"
+            shift 2
+            ;;
+        --node-size)
+            NODE_SIZE="$2"
+            shift 2
+            ;;
+        --node-rank)
+            NODE_RANK="$2"
+            shift 2
+            ;;
+        --node-ip)
+            NODE_IP="$2"
+            shift 2
+            ;;
+        --dp-master-ip)
+            DP_MASTER_IP="$2"
+            shift 2
+            ;;
+        --dp-master-port)
+            DP_MASTER_PORT="$2"
+            shift 2
+            ;;
+        --nixl-buffer-device)
+            NIXL_BUFFER_DEVICE="$2"
+            shift 2
+            ;;
+        --nixl-backend)
+            VLLM_NIXL_BACKEND="$2"
+            shift 2
+            ;;
+        --ucx-tls)
+            UCX_TLS="$2"
+            shift 2
+            ;;
+        --max-model-len)
+            MAX_MODEL_LEN="$2"
+            shift 2
+            ;;
+        --max-num-batched-tokens)
+            MAX_NUM_BATCHED_TOKENS="$2"
+            shift 2
+            ;;
+        --max-num-seqs)
+            MAX_NUM_SEQS="$2"
+            shift 2
+            ;;
+        --max-cudagraph-capture-size)
+            MAX_CUDAGRAPH_CAPTURE_SIZE="$2"
+            shift 2
+            ;;
+        --gpu-memory-utilization)
+            GPU_MEMORY_UTILIZATION="$2"
+            shift 2
+            ;;
+        --enforce-eager)
+            ENFORCE_EAGER=true
+            shift
+            ;;
+        --debug)
+            DEBUG=true
+            shift
+            ;;
+        --apc)
+            APC=true
+            shift
+            ;;
+        --profile)
+            PROFILE=true
+            shift
+            ;;
+        --recipe-cache)
+            RECIPE_CACHE=true
+            shift
+            ;;
+        --async)
+            ASYNC=true
+            shift
+            ;;
+        --warmup)
+            WARMUP=true
+            shift
+            ;;
+        --no-ep)
+            ENABLE_EXPERT_PARALLEL=false
+            shift
+            ;;
+        --log-dir)
+            LOG_DIR="$2"
+            shift 2
+            ;;
+        *)
+            echo "Unknown option: $1"
+            show_help
+            exit 1
+            ;;
+    esac
+done
+
+# Environment variables
+export no_proxy=localhost,${no_proxy}
+export VLLM_USE_V1=1
+if [ "$WARMUP" = false ]; then
+    export VLLM_SKIP_WARMUP=True
+fi
+export PT_HPU_LAZY_MODE=1
+export PT_HPU_ENABLE_LAZY_COLLECTIVES=1
+
+# Set flags based on --apc option
+PREFIX_CACHE=()
+if [ "$APC" = false ]; then
+    # APC will be disabled
+    PREFIX_CACHE+=(--no-enable-prefix-caching)
+fi
+
+# Set profiling flags based on --profile option
+if [ "$PROFILE" = true ]; then
+    export HABANA_PROFILE=1
+    export VLLM_PROFILER_ENABLED=1
+    echo "Profiling enabled: HABANA_PROFILE=1, VLLM_PROFILER_ENABLED=1"
+fi
+
+# Set recipe cache based on --recipe-cache option
+if [ "$RECIPE_CACHE" = true ]; then
+    export PT_HPU_RECIPE_CACHE_CONFIG="/workspace/pd_${SERVER_ROLE}_cache,false,131072"
+    echo "Recipe cache enabled: PT_HPU_RECIPE_CACHE_CONFIG=${PT_HPU_RECIPE_CACHE_CONFIG}"
+fi
+
+export VLLM_SCALE_ADJUSTMENT=0
+
+# Validate SERVER_ROLE
+if [[ "$SERVER_ROLE" != "prefill" && "$SERVER_ROLE" != "decode" ]]; then
+    echo "Error: SERVER_ROLE ($SERVER_ROLE) must be either 'prefill' or 'decode'"
+    exit 1
+fi
+
+# Create log directory if it doesn't exist
+RUN_TIMESTAMP=$(date +%Y%m%d_%H%M%S)
+LOG_DIR_FULL="${LOG_DIR}/xpyd_logs/${RUN_TIMESTAMP}"
+if [ ! -d "$LOG_DIR_FULL" ]; then
+    echo "Creating log directory: $LOG_DIR_FULL"
+    mkdir -p "$LOG_DIR_FULL"
+    if [ $? -ne 0 ]; then
+        echo "Error: Failed to create log directory: $LOG_DIR_FULL"
+        exit 1
+    fi
+fi
+
+# NIXL Config
+export VLLM_NIXL_SIDE_CHANNEL_HOST=${NODE_IP}
+if [ "$NIXL_BUFFER_DEVICE" == "cpu" ]; then
+    export VLLM_NIXL_DEVICE_TO_DEVICE=false
+else
+    export VLLM_NIXL_DEVICE_TO_DEVICE=true
+    # Add gaudi_gdr to UCX_TLS if not already present
+    if [[ "$UCX_TLS" != *"gaudi_gdr"* ]]; then
+        UCX_TLS="${UCX_TLS},gaudi_gdr"
+    fi
+    export UCX_MEMTYPE_CACHE=0
+fi
+
+# Bucket settings
+block_size=128
+input_min=128
+input_max=$MAX_MODEL_LEN
+output_max=$MAX_MODEL_LEN
+prompt_bs_step=2
+prompt_bs_min=1
+prompt_bs_max=$(( $MAX_NUM_BATCHED_TOKENS / $input_min ))
+prompt_bs_max=$(( $prompt_bs_max > $MAX_NUM_SEQS ? $MAX_NUM_SEQS : $prompt_bs_max ))
+# Cap at 8 to avoid extra prompt buckets on decode nodes
+prompt_bs_max=$(( $prompt_bs_max > 8 ? 8 : $prompt_bs_max ))
+prompt_bs_max=$(( ($prompt_bs_max + $prompt_bs_step - 1) / $prompt_bs_step * $prompt_bs_step ))
+prompt_seq_step=128
+prompt_seq_min=$(( ($input_min + $prompt_seq_step - 1) / $prompt_seq_step * $prompt_seq_step ))
+prompt_seq_max=$(( (($input_max + $prompt_seq_step - 1) / $prompt_seq_step) * $prompt_seq_step ))
+prompt_ctx_max=$(( ($MAX_MODEL_LEN - $block_size) / $block_size ))
+decode_bs_step=$(( ($MAX_NUM_SEQS + 15) / 16 ))
+decode_bs_min=1
+decode_bs_max=$(( ($MAX_NUM_SEQS + $decode_bs_step - 1) / $decode_bs_step * $decode_bs_step ))
+decode_block_step=$decode_bs_max
+decode_block_min=$(( ($input_min + $block_size - 1) / $block_size ))
+decode_block_min=$(( ($decode_block_min + $decode_block_step - 1) / $decode_block_step * $decode_block_step ))
+decode_block_max=$(( (($input_max + $output_max + $block_size - 1) / $block_size + 1) * $decode_bs_max ))
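+
+# Worked example with the defaults (MAX_MODEL_LEN=8192,
+# MAX_NUM_BATCHED_TOKENS=8192, MAX_NUM_SEQS=256), for reference:
+#   prompt_bs_max      = min(8192/128, 256), capped at 8, rounded to step 2 -> 8
+#   prompt_seq_min/max = 128 / 8192 (multiples of prompt_seq_step=128)
+#   decode_bs_step     = (256+15)/16 = 16, so decode_bs_max = 256
+#   decode_block_min   = 256; decode_block_max = (16384/128 + 1) * 256 = 33024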
+
+# Set role-specific configurations
+if [ "$SERVER_ROLE" == "prefill" ]; then
+    KV_ROLE="kv_producer"
+    export VLLM_SUPPORT_MOE_CHUNK="false"
+    # Bucket settings
+    export VLLM_EXPONENTIAL_BUCKETING=false
+    export VLLM_PROMPT_BS_BUCKET_MIN=$prompt_bs_min
+    export VLLM_PROMPT_BS_BUCKET_STEP=$prompt_bs_step
+    export VLLM_PROMPT_BS_BUCKET_MAX=$prompt_bs_max
+    export VLLM_PROMPT_QUERY_BUCKET_MIN=$prompt_seq_min
+    export VLLM_PROMPT_QUERY_BUCKET_STEP=$prompt_seq_step
+    export VLLM_PROMPT_QUERY_BUCKET_MAX=$prompt_seq_max
+    export VLLM_DECODE_BS_BUCKET_MIN=1
+    export VLLM_DECODE_BS_BUCKET_STEP=1
+    export VLLM_DECODE_BS_BUCKET_MAX=1
+    export VLLM_DECODE_BLOCK_BUCKET_MIN=2
+    export VLLM_DECODE_BLOCK_BUCKET_STEP=1
+    export VLLM_DECODE_BLOCK_BUCKET_MAX=2
+    if [ "$APC" = false ]; then
+        export VLLM_PROMPT_CTX_BUCKET_MAX=0
+    fi
+else
+    KV_ROLE="kv_consumer"
+    BASE_PORT=$((BASE_PORT+1000))
+    BASE_CHANNEL_PORT=$((BASE_CHANNEL_PORT+1000))
+    DP_MASTER_PORT=$((DP_MASTER_PORT+1000))
+    # MoE settings
+    export VLLM_SUPPORT_MOE_CHUNK="true"
+    export PT_HPU_MOE_CHUNK="64, 128"
+    export PT_HPU_MOE_TOKEN_BOUNDARY="2048, 4096"
+    export VLLM_EXPONENTIAL_BUCKETING=true
+    # Bucket settings
+    # export VLLM_PROMPT_BS_BUCKET_MIN=$prompt_bs_min
+    # export VLLM_PROMPT_BS_BUCKET_STEP=$prompt_bs_step
+    # export VLLM_PROMPT_BS_BUCKET_MAX=$prompt_bs_max
+    export VLLM_PROMPT_QUERY_BUCKET_MIN=1
+    # export VLLM_PROMPT_QUERY_BUCKET_MIN=$block_size
+    # export VLLM_PROMPT_CTX_BUCKET_MAX=$prompt_ctx_max
+    # export VLLM_DECODE_BS_BUCKET_MIN=$decode_bs_min
+    # export VLLM_DECODE_BS_BUCKET_STEP=$decode_bs_step
+    # export VLLM_DECODE_BS_BUCKET_MAX=$decode_bs_max
+    # export VLLM_DECODE_BLOCK_BUCKET_MIN=$decode_block_min
+    # export VLLM_DECODE_BLOCK_BUCKET_STEP=$decode_block_step
+    # export VLLM_DECODE_BLOCK_BUCKET_MAX=$decode_block_max
+fi
+
+# DP_SIZE must be 1 or equal to NUM_LOCAL_INSTANCES * NODE_SIZE
+if (( DP_SIZE != 1 && DP_SIZE != NUM_LOCAL_INSTANCES * NODE_SIZE )); then
+    echo "Error: DP_SIZE ($DP_SIZE) must be 1 or equal to NUM_LOCAL_INSTANCES ($NUM_LOCAL_INSTANCES) * NODE_SIZE ($NODE_SIZE)"
+    exit 1
+fi
+
+# Wait for a vLLM server to start answering on the given port.
+wait_for_server() {
+    local port=$1
+    timeout 1200 bash -c "
+        until env http_proxy=\"\" https_proxy=\"\" HTTP_PROXY=\"\" HTTPS_PROXY=\"\" \
+            curl -s ${NODE_IP}:${port}/v1/completions > /dev/null; do
+            sleep 1
+        done" && return 0 || return 1
+}
+
+# Launch all local server instances for the given model.
+launch_vllm_server() {
+    local model_name=$1
+    echo "Launching $SERVER_ROLE server with model: $model_name"
+
+    # Arrays to store all hosts and ports
+    HOSTS=()
+    PORTS=()
+
+    # Start instances
+    for i in $(seq 0 $((NUM_LOCAL_INSTANCES-1))); do
+        # Calculate port number (base port + instance number)
+        PORT=$((BASE_PORT+8*NODE_RANK+i))
+        # Calculate side channel port. Avoid clashes with TP workers.
+        SIDE_CHANNEL_PORT=$((BASE_CHANNEL_PORT+8*NODE_RANK+i))
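+        # Port layout with the default bases (illustrative): prefill node 0
+        # serves 8300, 8301, ...; node 1 starts at 8308. The decode role adds
+        # 1000 to BASE_PORT above, so decode node 0 starts at 9300.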
"$FULL_CMD &> \"$LOG_FILE\" &" + + # Store host and port for proxy configuration + HOSTS+=($NODE_IP) + PORTS+=($PORT) + done + + # Wait for all instances to start (in parallel) + echo "Waiting for all $SERVER_ROLE instances to start..." + wait_pids=() + for PORT in "${PORTS[@]}"; do + echo "Checking $SERVER_ROLE instance on port $PORT..." + wait_for_server $PORT & + wait_pids+=($!) + done + + # Wait for all background processes to complete + for pid in "${wait_pids[@]}"; do + wait $pid + if [ $? -ne 0 ]; then + echo "Error: Failed to start server on one of the ports" + exit 1 + fi + done + + # Print all launched servers + echo "==============================================" + echo "All $SERVER_ROLE servers launched successfully" + echo "Run ID: $RUN_TIMESTAMP" + echo "Log directory: $LOG_DIR_FULL" + echo "==============================================" + for i in "${!HOSTS[@]}"; do + echo "Server $((i+1)): ${HOSTS[$i]}:${PORTS[$i]}" + done + echo "==============================================" + + # Show profiling control commands if profiling is enabled + if [ "$PROFILE" = true ]; then + echo "" + echo "Profiling Control Commands:" + echo "==============================================" + + # Build the port list from actual PORTS array + if [ ${#PORTS[@]} -gt 1 ]; then + # Multiple ports - create comma-separated list + port_list="" + for i in "${!PORTS[@]}"; do + if [ $i -eq 0 ]; then + port_list="${PORTS[$i]}" + else + port_list="${port_list},${PORTS[$i]}" + fi + done + echo "Start profiling for all instances:" + echo " vllm_start_profile ${NODE_IP}:${port_list}" + echo "" + echo "Stop profiling for all instances:" + echo " vllm_stop_profile ${NODE_IP}:${port_list}" + else + # Single port + echo "Start profiling:" + echo " vllm_start_profile ${NODE_IP}:${PORTS[0]}" + echo "" + echo "Stop profiling:" + echo " vllm_stop_profile ${NODE_IP}:${PORTS[0]}" + fi + echo "==============================================" + fi +} + +launch_vllm_server "$MODEL" + diff --git a/examples/nixl/proxy_server.py b/examples/nixl/proxy_server.py new file mode 100644 index 000000000..f6405de17 --- /dev/null +++ b/examples/nixl/proxy_server.py @@ -0,0 +1,977 @@ +# SPDX-License-Identifier: Apache-2.0 +import argparse +import ipaddress +import itertools +import json +import logging +import os +import httpx + +os.environ["PT_HPU_LAZY_MODE"] = "1" +import sys +import threading +import time +from abc import ABC, abstractmethod +from typing import Callable, Optional + +import aiohttp +import requests +import uuid +import uvicorn +from colorlog.escape_codes import escape_codes +from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Request, status +from fastapi.responses import JSONResponse, StreamingResponse, PlainTextResponse +from transformers import AutoTokenizer +from asyncio import CancelledError + +formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s", "%Y-%m-%d %H:%M:%S") +handler = logging.StreamHandler() +handler.setFormatter(formatter) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(handler) +logger.propagate = False + +from fastapi.middleware.cors import CORSMiddleware + + +def log_info_blue(msg): + logger.info("%s%s%s", escape_codes["cyan"], msg, escape_codes["reset"]) + + +def log_info_green(msg): + logger.info("%s%s%s", escape_codes["green"], msg, escape_codes["reset"]) + + +def log_info_yellow(msg): + logger.info("%s%s%s", escape_codes["yellow"], msg, escape_codes["reset"]) + + +def log_info_red(msg): + 
logger.info("%s%s%s", escape_codes["red"], msg, escape_codes["reset"]) + + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=None, connect=None, sock_read=None, sock_connect=None) + + +async def D_first_token_generator( + generator_d, + callback_owner=None, + decode_instance: str = None, + req_len: int = None, +): + try: + async for chunk in generator_d: + yield chunk + finally: + if callback_owner: + callback_owner.exception_handler(prefill_instance=None, decode_instance=decode_instance, req_len=req_len) + + +class SchedulingPolicy(ABC): + def __init__(self): + self.lock = threading.Lock() + + @abstractmethod + def schedule(self, cycler: itertools.cycle): + raise NotImplementedError("Scheduling Proxy is not set.") + + +class Proxy: + def __init__( + self, + prefill_instances: list[str], + decode_instances: list[str], + model: str, + scheduling_policy: SchedulingPolicy, + custom_create_completion: Optional[Callable[[Request], StreamingResponse]] = None, + custom_create_chat_completion: Optional[Callable[[Request], StreamingResponse]] = None, + ): + self.prefill_instances = prefill_instances + self.decode_instances = decode_instances + self.prefill_cycler = itertools.cycle(prefill_instances) + self.decode_cycler = itertools.cycle(decode_instances) + self.model = model + self.scheduling_policy = scheduling_policy + self.custom_create_completion = custom_create_completion + self.custom_create_chat_completion = custom_create_chat_completion + self.router = APIRouter() + self.setup_routes() + self.generator = D_first_token_generator + self.tokenizer = AutoTokenizer.from_pretrained(model) + + def on_done( + self, + prefill_instance: str = None, + decode_instance: str = None, + req_len: int = None, + ): + self.schedule_completion(prefill_instance, decode_instance, req_len=req_len) + + def setup_routes(self): + self.router.post("/v1/completions", dependencies=[Depends(self.validate_json_request)])( + self.custom_create_completion if self.custom_create_completion else self.create_completion + ) + self.router.post("/v1/chat/completions", dependencies=[Depends(self.validate_json_request)])( + self.custom_create_chat_completion if self.custom_create_chat_completion else self.create_chat_completion + ) + + self.router.options("/v1/completions")(lambda: None) + self.router.options("/v1/chat/completions")(lambda: None) + self.router.options("/v1/models")(lambda: None) + self.router.options("/status")(lambda: None) + self.router.options("/health")(lambda: None) + self.router.options("/ping")(lambda: None) + self.router.options("/tokenize")(lambda: None) + self.router.options("/detokenize")(lambda: None) + self.router.options("/version")(lambda: None) + self.router.options("/v1/embeddings")(lambda: None) + self.router.options("/pooling")(lambda: None) + self.router.options("/score")(lambda: None) + self.router.options("/v1/score")(lambda: None) + self.router.options("/rerank")(lambda: None) + self.router.options("/v1/rerank")(lambda: None) + self.router.options("/v2/rerank")(lambda: None) + self.router.options("/invocations")(lambda: None) + + self.router.get("/status", response_class=JSONResponse)(self.get_status) + self.router.post("/instances/add", dependencies=[Depends(self.api_key_authenticate)])( + self.add_instance_endpoint + ) + self.router.get("/health", response_class=PlainTextResponse)(self.get_health) + self.router.get("/ping", response_class=PlainTextResponse)(self.get_ping) + self.router.post("/ping", response_class=PlainTextResponse)(self.get_ping) + self.router.post("/tokenize", 
+        self.router.post("/detokenize", response_class=JSONResponse)(self.post_detokenize)
+        self.router.get("/v1/models", response_class=JSONResponse)(self.get_models)
+        self.router.get("/version", response_class=JSONResponse)(self.get_version)
+        self.router.post("/v1/embeddings", response_class=JSONResponse)(self.post_embeddings)
+        self.router.post("/pooling", response_class=JSONResponse)(self.post_pooling)
+        self.router.post("/score", response_class=JSONResponse)(self.post_score)
+        self.router.post("/v1/score", response_class=JSONResponse)(self.post_scorev1)
+        self.router.post("/rerank", response_class=JSONResponse)(self.post_rerank)
+        self.router.post("/v1/rerank", response_class=JSONResponse)(self.post_rerankv1)
+        self.router.post("/v2/rerank", response_class=JSONResponse)(self.post_rerankv2)
+        self.router.post("/invocations", response_class=JSONResponse)(self.post_invocations)
+
+    async def get_from_instance(self, path: str, is_full_instancelist: bool = False):
+        if not self.prefill_instances:
+            return JSONResponse(content={"error": "No instances available"}, status_code=500)
+
+        if not is_full_instancelist:
+            instances = [self.prefill_instances[0]]
+        else:
+            instances = self.prefill_instances + self.decode_instances
+
+        results = {}
+        async with aiohttp.ClientSession() as session:
+            for inst in instances:
+                url = f"http://{inst}{path}"
+                try:
+                    async with session.get(url) as resp:
+                        try:
+                            data = await resp.json()
+                            dtype = "json"
+                        except aiohttp.ContentTypeError:
+                            data = await resp.text()
+                            dtype = "text"
+                        results[inst] = {
+                            "status": resp.status,
+                            "type": dtype,
+                            "data": data,
+                        }
+                except Exception as e:
+                    results[inst] = {"status": 500, "error": str(e)}
+                    logger.warning("Failed to fetch %s: %s, continuing...", url, e)
+
+        return JSONResponse(content=results, status_code=200)
+
+    async def get_version(self):
+        return await self.get_from_instance("/version")
+
+    async def get_models(self):
+        return await self.get_from_instance("/v1/models")
+
+    async def get_health(self):
+        return await self.get_from_instance("/health", is_full_instancelist=True)
+
+    async def get_ping(self):
+        return await self.get_from_instance("/ping", is_full_instancelist=True)
+
+    async def post_to_instance(self, request: Request, path: str, json_template: dict):
+        body = await request.json()
+
+        missing = [k for k in json_template if k not in body]
+        if missing:
+            return JSONResponse(
+                {"error": f"Missing required fields: {', '.join(missing)}"},
+                status_code=400,
+            )
+
+        payload = json_template.copy()
+        payload.update(body)
+
+        url = f"http://{self.prefill_instances[0]}{path}"
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(url, json=payload) as resp:
+                    try:
+                        content = await resp.json()
+                    except aiohttp.ContentTypeError:
+                        content = {"raw": await resp.text()}
+                    return JSONResponse(content, status_code=resp.status)
+        except Exception as e:
+            return JSONResponse({"error": f"Failed to fetch {url}, reason: {str(e)}"}, status_code=500)
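+
+    # The handlers below check that every key in json_template is present in
+    # the request body, then forward the request to the first prefill instance
+    # and relay its response. Illustrative call (hypothetical proxy host and
+    # model):
+    #   curl -s localhost:8000/tokenize -H 'Content-Type: application/json' \
+    #       -d '{"model": "ibm-research/PowerMoE-3b", "prompt": "Hello"}'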
+
+    async def post_detokenize(self, request: Request):
+        json_template = {"model": "", "tokens": []}
+        return await self.post_to_instance(request, "/detokenize", json_template)
+
+    async def post_tokenize(self, request: Request):
+        json_template = {"model": "", "prompt": ""}
+        return await self.post_to_instance(request, "/tokenize", json_template)
+
+    async def post_embeddings(self, request: Request):
+        json_template = {"model": "", "input": ""}
+        return await self.post_to_instance(request, "/v1/embeddings", json_template)
+
+    async def post_pooling(self, request: Request):
+        json_template = {"model": "", "messages": ""}
+        return await self.post_to_instance(request, "/pooling", json_template)
+
+    async def post_score(self, request: Request):
+        json_template = {"model": "", "text_1": "", "text_2": "", "predictions": ""}
+        return await self.post_to_instance(request, "/score", json_template)
+
+    async def post_scorev1(self, request: Request):
+        json_template = {"model": "", "text_1": "", "text_2": "", "predictions": ""}
+        return await self.post_to_instance(request, "/v1/score", json_template)
+
+    async def post_rerank(self, request: Request):
+        json_template = {"model": "", "query": "", "documents": ""}
+        return await self.post_to_instance(request, "/rerank", json_template)
+
+    async def post_rerankv1(self, request: Request):
+        json_template = {"model": "", "query": "", "documents": ""}
+        return await self.post_to_instance(request, "/v1/rerank", json_template)
+
+    async def post_rerankv2(self, request: Request):
+        json_template = {"model": "", "query": "", "documents": ""}
+        return await self.post_to_instance(request, "/v2/rerank", json_template)
+
+    async def post_invocations(self, request: Request):
+        json_template = {"model": "", "prompt": ""}
+        return await self.post_to_instance(request, "/invocations", json_template)
+
+    async def validate_json_request(self, raw_request: Request):
+        content_type = raw_request.headers.get("content-type", "").lower()
+        if content_type != "application/json":
+            raise HTTPException(
+                status_code=415,
+                detail="Unsupported Media Type: Only 'application/json' is allowed",
+            )
+
+    def api_key_authenticate(self, x_api_key: str = Header(...)):
+        expected_api_key = os.environ.get("ADMIN_API_KEY")
+        if not expected_api_key:
+            logger.error("ADMIN_API_KEY is not set in the environment.")
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail="Server configuration error.",
+            )
+        if x_api_key != expected_api_key:
+            logger.warning("Unauthorized access attempt with API Key: %s", x_api_key)
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail="Forbidden: Invalid API Key.",
+            )
+
+    async def validate_instance(self, instance: str) -> bool:
+        url = f"http://{instance}/v1/models"
+        try:
+            async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as client:
+                logger.info("Verifying %s ...", instance)
+                async with client.get(url) as response:
+                    if response.status != 200:
+                        return False
+                    data = await response.json()
+                    if "data" in data and len(data["data"]) > 0:
+                        model_cur = data["data"][0].get("id", "")
+                        if model_cur == self.model:
+                            logger.info("Instance: %s could be added.", instance)
+                            return True
+                        logger.warning(
+                            "Mismatch model %s : %s != %s",
+                            instance,
+                            model_cur,
+                            self.model,
+                        )
+                    return False
+        except aiohttp.ClientError as e:
+            logger.error(str(e))
+            return False
+        except Exception as e:
+            logger.error(str(e))
+            return False
+
+    async def add_instance_endpoint(self, request: Request):
+        try:
+            data = await request.json()
+            logger.warning(str(data))
+            instance_type = data.get("type")
+            instance = data.get("instance")
+            if instance_type not in ["prefill", "decode"]:
+                raise HTTPException(status_code=400, detail="Invalid instance type.")
+            if not instance or ":" not in instance:
+                raise HTTPException(status_code=400, detail="Invalid instance format.")
+            host, port_str = instance.split(":")
+            try:
+                if host != "localhost":
+                    ipaddress.ip_address(host)
+                port = int(port_str)
+                if not (0 < port < 65536):
+                    raise HTTPException(status_code=400, detail="Invalid port number.")
+            except Exception as e:
+                raise HTTPException(status_code=400, detail="Invalid instance address.") from e
+
+            is_valid = await self.validate_instance(instance)
+            if not is_valid:
+                raise HTTPException(status_code=400, detail="Instance validation failed.")
+
+            if instance_type == "prefill":
+                with self.scheduling_policy.lock:
+                    if instance not in self.prefill_instances:
+                        self.prefill_instances.append(instance)
+                        self.prefill_cycler = itertools.cycle(self.prefill_instances)
+                    else:
+                        raise HTTPException(status_code=400, detail="Instance already exists.")
+            else:
+                with self.scheduling_policy.lock:
+                    if instance not in self.decode_instances:
+                        self.decode_instances.append(instance)
+                        self.decode_cycler = itertools.cycle(self.decode_instances)
+                    else:
+                        raise HTTPException(status_code=400, detail="Instance already exists.")
+
+            return JSONResponse(content={"message": f"Added {instance} to {instance_type}_instances."})
+        except HTTPException as http_exc:
+            raise http_exc
+        except Exception as e:
+            logger.error("Error in add_instance_endpoint: %s", str(e))
+            raise HTTPException(status_code=500, detail=str(e)) from e
+
+    async def forward_request(self, url, data, request_id: str, use_chunked=True):
+        async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+            headers = {
+                "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+                "X-Request-Id": request_id,
+            }
+            try:
+                async with session.post(url=url, json=data, headers=headers) as response:
+                    if 200 <= response.status < 300 or 400 <= response.status < 500:  # noqa: E501
+                        if use_chunked:
+                            async for chunk_bytes in response.content.iter_chunked(  # noqa: E501
+                                1024
+                            ):
+                                yield chunk_bytes
+                        else:
+                            content = await response.read()
+                            yield content
+                    else:
+                        error_content = await response.text()
+                        try:
+                            error_content = json.loads(error_content)
+                        except json.JSONDecodeError:
+                            pass
+                        logger.error(
+                            "Request failed with status %s: %s",
+                            response.status,
+                            error_content,
+                        )
+                        raise HTTPException(
+                            status_code=response.status,
+                            detail=f"Request failed with status {response.status}: {error_content}",
+                        )
+            except aiohttp.ClientError as e:
+                logger.error("ClientError occurred: %s", str(e))
+                raise HTTPException(
+                    status_code=502,
+                    detail="Bad Gateway: Error communicating with upstream server.",
+                ) from e
+            except Exception as e:
+                logger.error("Unexpected error: %s", str(e))
+                raise HTTPException(status_code=500, detail=str(e)) from e
+
+    def schedule(
+        self,
+        cycler: itertools.cycle,
+        is_prompt: Optional[bool] = None,
+        request_len: Optional[int] = None,
+    ) -> str:
+        return self.scheduling_policy.schedule(cycler, is_prompt, request_len)
+
+    def schedule_completion(
+        self,
+        prefill_instance: Optional[str] = None,
+        decode_instance: Optional[str] = None,
+        req_len: Optional[int] = None,
+    ):
+        self.scheduling_policy.schedule_completion(
+            prefill_instance=prefill_instance,
+            decode_instance=decode_instance,
+            req_len=req_len,
+        )
+
+    async def get_status(self):
+        proxy_status = {
+            "prefill_node_count": len(self.prefill_instances),
+            "decode_node_count": len(self.decode_instances),
+            "prefill_nodes": self.prefill_instances,
+            "decode_nodes": self.decode_instances,
+        }
+        return proxy_status
+
+    def get_total_token_length(self, prompt):
+        fake_len = 100
+        if isinstance(prompt, str):
+            return len(self.tokenizer(prompt)["input_ids"])
+        elif isinstance(prompt, list):
+            if all(isinstance(p, str) for p in prompt):
+                return sum(len(self.tokenizer(p)["input_ids"]) for p in prompt)
+            elif all(isinstance(p, list) and all(isinstance(x, int) for x in p) for p in prompt):
+                # Already tokenized
+                return sum(len(p) for p in prompt)
+            else:
+                logger.error(
+                    "Unsupported prompt format: %s / nested types. Value: %r",
+                    type(prompt),
+                    prompt,
+                )
+                return fake_len
+        else:
+            logger.error("Unsupported prompt type: %s", type(prompt))
+            return fake_len
+
+    def exception_handler(self, prefill_instance=None, decode_instance=None, req_len=None):
+        if prefill_instance or decode_instance:
+            try:
+                self.on_done(
+                    prefill_instance=prefill_instance,
+                    decode_instance=decode_instance,
+                    req_len=req_len,
+                )
+            except Exception as e:
+                logger.error(f"Error releasing instances: {e}")
+                raise
+
+    async def send_request_to_service(
+        self, instance: str, endpoint: str, req_data: dict, request_id: str
+    ):  # yapf: disable
+        """
+        Send the prefill request to the chosen instance and return its
+        (non-streaming, single-token) response.
+        """
+        req_data = req_data.copy()
+        req_data["kv_transfer_params"] = {
+            "do_remote_decode": True,
+            "do_remote_prefill": False,
+            "remote_engine_id": None,
+            "remote_block_ids": None,
+            "remote_host": None,
+            "remote_port": None,
+        }
+        req_data["stream"] = False
+        req_data["max_tokens"] = 1
+        if "max_completion_tokens" in req_data:
+            req_data["max_completion_tokens"] = 1
+        if "stream_options" in req_data:
+            del req_data["stream_options"]
+        headers = {
+            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+            "X-Request-Id": request_id,
+        }
+
+        prefiller_base_url = f"http://{instance}/"
+        async with httpx.AsyncClient(timeout=None, base_url=prefiller_base_url) as client:
+            response = await client.post(endpoint, json=req_data, headers=headers)  # yapf: disable
+            response.raise_for_status()
+
+        return response
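+
+    # Disaggregated flow in brief: the proxy first sends the request to a
+    # prefill instance with max_tokens=1 and do_remote_decode=True. The
+    # prefiller computes the KV cache and responds with kv_transfer_params
+    # describing where that cache lives (the fields stubbed out above, e.g.
+    # remote_engine_id, remote_block_ids, remote_host, remote_port). Those
+    # params are attached to the original request, which is then streamed from
+    # a decode instance that pulls the KV blocks over NIXL; the exact populated
+    # contents are supplied by vLLM's NixlConnector.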
+
+    async def create_completion(self, raw_request: Request):
+        try:
+            request = await raw_request.json()
+            request_id = str(uuid.uuid4())
+
+            total_length = 0
+            prefill_instance = None
+            decode_instance = None
+
+            kv_prepare_request = request.copy()
+
+            start_time = time.time()
+            prompt = kv_prepare_request.get("prompt")
+            total_length = self.get_total_token_length(prompt)
+            end_time = time.time()
+
+            log_info_green(
+                f"create_completion -- prompt length: {total_length}, "
+                f"tokenizer took "
+                f"{(end_time - start_time) * 1000:.2f} ms"
+            )
+            prefill_instance = self.schedule(self.prefill_cycler, is_prompt=True, request_len=total_length)
+
+            # Send request to prefill service
+            response = await self.send_request_to_service(
+                prefill_instance, "/v1/completions", kv_prepare_request, request_id
+            )  # yapf: disable
+
+            # Perform kv recv and decoding stage
+            response_json = response.json()
+            kv_transfer_params = response_json.get("kv_transfer_params", {})
+            if kv_transfer_params:
+                request["kv_transfer_params"] = kv_transfer_params
+            decode_instance = self.schedule(self.decode_cycler, is_prompt=False, request_len=total_length)
+            try:
+                generator_d = self.forward_request(f"http://{decode_instance}/v1/completions", request, request_id)
+            except HTTPException as http_exc:
+                self.exception_handler(prefill_instance, decode_instance, total_length)
+                raise http_exc
+
+            if request.get("stream", False):
+                generator_class = self.generator
+            else:
+                # For stream=False requests, cannot use P first token
+                generator_class = D_first_token_generator
+            final_generator = generator_class(
+                generator_d,
+                self,
+                decode_instance,
+                req_len=total_length,
+            )
+            media_type = "text/event-stream" if request.get("stream", False) else "application/json"
+
+            async def wrapped_generator():
+                try:
+                    async for chunk in final_generator:
+                        yield chunk
+                except CancelledError:
+                    logger.warning("[0] Client disconnected during create_completion (CancelledError)")
+                except Exception as e:
+                    logger.error("[1] Exception in wrapped_generator: %s", str(e))
+                    raise
+
+            return StreamingResponse(wrapped_generator(), media_type=media_type)
+        except Exception:
+            exc_info = sys.exc_info()
+            error_messages = [str(e) for e in exc_info if e]
+            print("Error occurred in disagg proxy server")
+            print(error_messages)
+            return StreamingResponse(content=iter(error_messages), media_type="application/json")
+
+    async def create_chat_completion(self, raw_request: Request):
+        try:
+            request = await raw_request.json()
+            request_id = str(uuid.uuid4())
+
+            total_length = 0
+            prefill_instance = None
+            decode_instance = None
+
+            kv_prepare_request = request.copy()
+
+            start_time = time.time()
+            # prefill stage
+            total_length = sum(self.get_total_token_length(msg["content"]) for msg in kv_prepare_request["messages"])
+            end_time = time.time()
+            log_info_green(
+                f"create_chat_completion -- prompt length: {total_length}, "
+                f"tokenizer took "
+                f"{(end_time - start_time) * 1000:.2f} ms"
+            )
+
+            prefill_instance = self.schedule(self.prefill_cycler, is_prompt=True, request_len=total_length)
+
+            # Send request to prefill service
+            response = await self.send_request_to_service(
+                prefill_instance, "/v1/chat/completions", kv_prepare_request, request_id
+            )  # yapf: disable
+
+            # Perform kv recv and decoding stage
+            response_json = response.json()
+            kv_transfer_params = response_json.get("kv_transfer_params", {})
+            if kv_transfer_params:
+                request["kv_transfer_params"] = kv_transfer_params
+            decode_instance = self.schedule(self.decode_cycler, is_prompt=False, request_len=total_length)
+
+            try:
+                generator_d = self.forward_request(
+                    "http://" + decode_instance + "/v1/chat/completions", request, request_id
+                )
+            except HTTPException as http_exc:
+                self.exception_handler(prefill_instance, decode_instance, total_length)
+                raise http_exc
+
+            if request.get("stream", False):
+                generator_class = self.generator
+            else:
+                # For stream=False requests, cannot use P first token
+                generator_class = D_first_token_generator
+            final_generator = generator_class(
+                generator_d,
+                self,
+                decode_instance,
+                req_len=total_length,
+            )
+            media_type = "text/event-stream" if request.get("stream", False) else "application/json"
+
+            async def wrapped_generator():
+                try:
+                    async for chunk in final_generator:
+                        yield chunk
+                except CancelledError:
+                    logger.warning("[0] Client disconnected during create_chat_completion (CancelledError)")
+                except Exception as e:
+                    logger.error("[1] Exception in wrapped_generator: %s", str(e))
+                    raise
+
+            return StreamingResponse(wrapped_generator(), media_type=media_type)
+        except Exception:
+            exc_info = sys.exc_info()
+            error_messages = [str(e) for e in exc_info if e]
+            print("Error occurred in disagg proxy server")
+            print(error_messages)
+            return StreamingResponse(content=iter(error_messages), media_type="application/json")
+
+    def remove_instance_endpoint(self, instance_type, instance):
+        return
+
+
+class RoundRobinSchedulingPolicy(SchedulingPolicy):
+    def __init__(self):
+        print("RoundRobinSchedulingPolicy")
+        super().__init__()
+
+    def safe_next(self, cycler: itertools.cycle):
+        with self.lock:
+            return next(cycler)
+
+    def schedule(
+        self,
+        cycler: itertools.cycle,
+        is_prompt: Optional[bool] = None,
+        request_len: Optional[int] = None,
+    ) -> str:
+        return self.safe_next(cycler)
+
+
+class LoadBalancedScheduler(SchedulingPolicy):
+    def __init__(self, prefill_instances: list[str], decode_instances: list[str]):
+        self.prefill_utils_counter = [0] * len(prefill_instances)
+        self.prefill_bs_counter = [0] * len(prefill_instances)
+        self.decode_kv_utils_counter = [0] * len(decode_instances)  # KV cache utils
+        self.decode_bs_counter = [0] * len(decode_instances)
+
+        self.prefill_instances = prefill_instances
+        self.decode_instances = decode_instances
+        print(
+            "LoadBalancedScheduler, prefill/decode instance counts =",
+            len(self.prefill_bs_counter),
+            len(self.decode_bs_counter),
+        )
+        print("LoadBalancedScheduler, self.prefill_instances =", self.prefill_instances)
+        print("LoadBalancedScheduler, self.decode_instances =", self.decode_instances)
+        self.prefill_schedule_index = 0
+        self.prefill_schedule_completion_index = 0
+        self.decode_schedule_index = 0
+        self.decode_schedule_completion_index = 0
+
+        super().__init__()
+
+    def schedule(
+        self,
+        cycler: itertools.cycle,
+        is_prompt: Optional[bool] = None,
+        request_len: Optional[int] = None,
+    ) -> str:
+        with self.lock:
+            if is_prompt:
+                min_value = min(self.prefill_utils_counter)
+                min_index = self.prefill_utils_counter.index(min_value)
+                self.prefill_bs_counter[min_index] += 1
+                self.prefill_utils_counter[min_index] += request_len
+                self.prefill_schedule_index += 1
+                log_info_yellow(f"[prefill schedule] instance = {min_index}, min_tokens = {min_value}")
+                return self.prefill_instances[min_index]
+            else:
+                min_value = min(self.decode_bs_counter)
+
+                if min_value == 0:
+                    min_index = self.decode_bs_counter.index(min_value)
+                else:
+                    min_indices = [i for i, val in enumerate(self.decode_bs_counter) if val == min_value]
+                    min_index = min(min_indices, key=lambda i: self.decode_kv_utils_counter[i])
+
+                self.decode_bs_counter[min_index] += 1
+                self.decode_kv_utils_counter[min_index] += request_len
+                self.decode_schedule_index += 1
+                log_info_blue(f"[decode schedule] instance = {min_index}, min_batch = {min_value}")
+                log_info_blue(f"decode_bs_counter: {self.decode_bs_counter}")
+                log_info_blue(f"decode_kv_utils_counter: {self.decode_kv_utils_counter}")
+
+                return self.decode_instances[min_index]
+
+    def schedule_completion(
+        self,
+        prefill_instance: Optional[str] = None,
+        decode_instance: Optional[str] = None,
+        req_len: Optional[int] = None,
+    ):
+        with self.lock:
+            if prefill_instance:
+                index = self.prefill_instances.index(prefill_instance)
+                if self.prefill_bs_counter[index] == 0:
+                    logger.warning("No alive requests for prefill instance, skipping...")
+                else:
+                    self.prefill_schedule_completion_index += 1
+                    log_info_yellow(f"[prefill completion] instance = {index}, req_len={req_len}")
+
+                    self.prefill_bs_counter[index] -= 1
+                    all_zero = True
+                    for index, _ in enumerate(self.prefill_instances):
+                        if self.prefill_bs_counter[index] != 0:
+                            all_zero = False
+                            break
+                    if all_zero:
+                        log_info_red("All prefill requests completed; resetting prefill token counters")
+                        for index, _ in enumerate(self.prefill_instances):
+                            self.prefill_utils_counter[index] = 0
+                    else:
+                        index = self.prefill_instances.index(prefill_instance)
+                        self.prefill_utils_counter[index] -= req_len
+
+            if decode_instance:
+                index = self.decode_instances.index(decode_instance)
+                if self.decode_bs_counter[index] == 0:
+                    logger.warning("No alive requests for decode instance, skipping...")
+                else:
+                    self.decode_schedule_completion_index += 1
+                    log_info_blue(f"[decode completion] instance = {index}, req_len={req_len}")
+
+                    self.decode_bs_counter[index] -= 1
+                    all_zero = True
+                    for index, _ in enumerate(self.decode_instances):
+                        if self.decode_bs_counter[index] != 0:
+                            all_zero = False
+                            break
+                    if all_zero:
+                        log_info_red("All decode requests completed; resetting decode KV counters")
+                        self.decode_kv_utils_counter = [0] * len(self.decode_instances)
+                    else:
+                        index = self.decode_instances.index(decode_instance)
+                        self.decode_kv_utils_counter[index] -= req_len
+                log_info_blue(f"decode_bs_counter: {self.decode_bs_counter}")
+                log_info_blue(f"decode_kv_utils_counter: {self.decode_kv_utils_counter}")
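+
+
+# Scheduling sketch (illustrative): with two decode instances and three
+# back-to-back requests of equal length, schedule() picks the instance with
+# the smallest in-flight batch, breaking ties by KV utilization, so the
+# requests land on instances 0, 1, 0. schedule_completion() reverses the
+# bookkeeping when a request finishes and resets the utilization counters
+# once a pool fully drains.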
ValueError(f"Invalid instance {instance}: {str(e)}") from e + + def verify_model_config(self, instances: list, model: str) -> None: + for instance in instances: + try: + response = requests.get(f"http://{instance}/v1/models") + if response.status_code == 200: + model_cur = response.json()["data"][0]["id"] + if model_cur != model: + raise ValueError(f"{instance} serves a different model: {model_cur} != {model}") + else: + raise ValueError(f"Cannot get model id from {instance}!") + except requests.RequestException as e: + raise ValueError(f"Error communicating with {instance}: {str(e)}") from e + + def run_server(self): + app = FastAPI() + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=False, + allow_methods=["*"], + allow_headers=["*"], + ) + + app.include_router(self.proxy_instance.router) + config = uvicorn.Config(app, host="0.0.0.0", port=self.port, loop="uvloop") + server = uvicorn.Server(config) + server.run() + + +if __name__ == "__main__": + # Todo: allow more config + parser = argparse.ArgumentParser("vLLM disaggregated proxy server.") + parser.add_argument("--model", "-m", type=str, required=True, help="Model name") + + parser.add_argument( + "--prefill", + "-p", + type=str, + nargs="+", + help="List of prefill node URLs (host:port)", + ) + + parser.add_argument( + "--decode", + "-d", + type=str, + nargs="+", + help="List of decode node URLs (host:port)", + ) + + parser.add_argument( + "--port", + type=int, + default=8000, + help="Server port number", + ) + + parser.add_argument( + "--roundrobin", + action="store_true", + help="Use Round Robin scheduling for load balancing", + ) + + parser.add_argument( + "--bypass-proxy", + action="store_true", + help="Bypass HTTP/HTTPS proxy settings when connecting to prefill and decode instances", + ) + + args = parser.parse_args() + + # Set proxy bypass environment variables if requested + if args.bypass_proxy: + logger.info("Clearing proxy environment variables...") + os.environ["http_proxy"] = "" + os.environ["https_proxy"] = "" + os.environ["HTTP_PROXY"] = "" + os.environ["HTTPS_PROXY"] = "" + + if args.roundrobin: + proxy_server = ProxyServer(args=args) + else: + proxy_server = ProxyServer(args=args, scheduling_policy=LoadBalancedScheduler) + + proxy_server.run_server()