diff --git a/examples/online_serving/epd/launch.sh b/examples/online_serving/epd/launch.sh new file mode 100644 index 000000000000..d7f81b883f6a --- /dev/null +++ b/examples/online_serving/epd/launch.sh @@ -0,0 +1,114 @@ +#!/usr/bin/env bash +set -euo pipefail + +export VLLM_VERSION=0.10.0 + +wait_for_server() { + local port=$1 + timeout 12000 bash -c ' + until curl -s "http://localhost:'"$port"'/v1/chat/completions" > /dev/null; do + sleep 1 + done + ' && return 0 || return 1 +} + +MODEL="${MODEL:-Qwen2.5-VL-3B-Instruct}" +MODEL_NAME="${MODEL_NAME:-qwen2.5-vl-3b-instruct}" + +LOG_PATH=${LOG_PATH:-./logs} +mkdir -p "$LOG_PATH" + +ENCODE_PORT=19534 +PREFILL_DECODE_PORT=19535 +PROXY_PORT=10001 + +GPU_E=0 +GPU_PD=1 + +START_TIME=$(date +"%Y%m%d_%H%M%S") +ENC_LOG="$LOG_PATH/encoder.log" +PD_LOG="$LOG_PATH/pd.log" +PROXY_LOG="$LOG_PATH/proxy.log" +PID_FILE="./pid.txt" + +SHARED_STORAGE_PATH="/dev/shm/epd" + +############################################################################### +# Encoder worker +############################################################################### +start_encoder_worker() { + ASCEND_RT_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ + --gpu-memory-utilization 0.0 \ + --port "$ENCODE_PORT" \ + --enable-request-id-headers \ + --no-enable-prefix-caching \ + --max-num-seqs 128 \ + --max-model-len 3072 \ + --max-num-batched-tokens 40000 \ + --served-model-name "$MODEL_NAME" \ + --enforce-eager \ + --ec-transfer-config '{ + "ec_connector": "ECSharedStorageConnector", + "ec_role": "ec_producer", + "ec_connector_extra_config": { + "shared_storage_path": "'"$SHARED_STORAGE_PATH"'", + "ec_max_num_scheduled_tokens": "4096" + } + }' \ + >"$ENC_LOG" 2>&1 & + + echo $! >> "$PID_FILE" +} + +start_encoder_worker + +############################################################################### +# Prefill / decode worker +############################################################################### +start_prefill_decode_worker() { + ASCEND_RT_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \ + --gpu-memory-utilization 0.95 \ + --port "$PREFILL_DECODE_PORT" \ + --no-enable-prefix-caching \ + --enable-request-id-headers \ + --max-num-seqs 128 \ + --max-model-len 3072 \ + --max-num-batched-tokens 40000 \ + --served-model-name "$MODEL_NAME" \ + --enforce-eager \ + --ec-transfer-config '{ + "ec_connector": "ECSharedStorageConnector", + "ec_role": "ec_consumer", + "ec_connector_extra_config": { + "shared_storage_path": "'"$SHARED_STORAGE_PATH"'" + } + }' \ + >"$PD_LOG" 2>&1 & + + echo $! >> "$PID_FILE" +} + +start_prefill_decode_worker + +# Wait until both workers are ready +# wait_for_server "$ENCODE_PORT" +# wait_for_server "$PREFILL_DECODE_PORT" + +############################################################################### +# Proxy +############################################################################### +start_proxy() { + python proxy.py \ + --host "127.0.0.1" \ + --port "$PROXY_PORT" \ + --encode-servers-urls "http://localhost:$ENCODE_PORT" \ + --prefill-decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \ + >"$PROXY_LOG" 2>&1 & + + echo $! >> "$PID_FILE" +} + +start_proxy + +# wait_for_server "$PROXY_PORT" +echo "All services are up!" 
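The `--ec-transfer-config` JSON that launch.sh passes to each `vllm serve` command is parsed into the `ECTransferConfig` dataclass added to `vllm/config.py` later in this patch. As a rough illustration of how the encoder and prefill/decode workers differ, here is a minimal sketch that builds the two configs directly in Python (assuming the dataclass can be instantiated like other vLLM config classes; in the scripts the JSON flows in through `--ec-transfer-config` / `EngineArgs` instead):

```python
# Minimal sketch: the two EC-transfer configs used by launch.sh, expressed as
# ECTransferConfig instances. Field names come from the dataclass added in
# vllm/config.py by this patch; direct instantiation here is for illustration only.
from vllm.config import ECTransferConfig

encoder_cfg = ECTransferConfig(
    ec_connector="ECSharedStorageConnector",
    ec_role="ec_producer",  # encoder worker: produces and saves encoder caches
    ec_connector_extra_config={
        "shared_storage_path": "/dev/shm/epd",
        "ec_max_num_scheduled_tokens": "4096",
    },
)

pd_cfg = ECTransferConfig(
    ec_connector="ECSharedStorageConnector",
    ec_role="ec_consumer",  # prefill/decode worker: loads saved encoder caches
    ec_connector_extra_config={"shared_storage_path": "/dev/shm/epd"},
)

# The role properties are what the scheduler, model runner, and weight loader
# check elsewhere in this patch.
assert encoder_cfg.is_ec_producer and not encoder_cfg.is_ec_consumer
assert pd_cfg.is_ec_consumer and not pd_cfg.is_ec_producer

# Extra-config values are read back with get_from_extra_config, e.g. the
# scheduler's ec_max_num_scheduled_tokens budget override.
print(encoder_cfg.get_from_extra_config("ec_max_num_scheduled_tokens", 0))
```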
\ No newline at end of file diff --git a/examples/online_serving/epd/proxy.py b/examples/online_serving/epd/proxy.py new file mode 100644 index 000000000000..f3cefbf524da --- /dev/null +++ b/examples/online_serving/epd/proxy.py @@ -0,0 +1,286 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# api_proxy.py +import argparse +import asyncio +import copy +import logging +import random +import uuid +from collections.abc import AsyncIterator +from typing import Optional + +import aiohttp +import uvicorn +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse, StreamingResponse + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +app = FastAPI() + +encode_session: Optional[aiohttp.ClientSession] = None +decode_session: Optional[aiohttp.ClientSession] = None + + +@app.on_event("startup") +async def startup_event(): + global encode_session, decode_session + encode_session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit=0), + timeout=aiohttp.ClientTimeout(total=100000), + ) + decode_session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit=0), + timeout=aiohttp.ClientTimeout(total=100000), + ) + + +@app.on_event("shutdown") +async def shutdown_event(): + global encode_session, decode_session + if encode_session: + await encode_session.close() + if decode_session: + await decode_session.close() + + +def has_mm_input(request_data: dict): + if "messages" not in request_data: + return False + for message in request_data["messages"]: + if not isinstance(message.get("content"), list): + continue + for content_item in message["content"]: + if content_item.get("type") in ["image_url", "audio_url", "input_audio"]: + return True + return False + + +async def forward_streaming_request( + request_data: dict, + request_id: str, + e_server_url: str, + pd_server_url: str, +) -> AsyncIterator[str]: + headers = {"x-request-id": request_id} + # Skip request to encoder instance if we don't have mm input + if has_mm_input(request_data): + encoder_request_data = copy.deepcopy(request_data) + encoder_request_data["max_tokens"] = 1 + encoder_request_data["stream"] = False + encoder_request_data.pop("stream_options", None) + if "max_completion_tokens" in encoder_request_data: + encoder_request_data["max_completion_tokens"] = 1 + task1 = asyncio.create_task( + encode_session.post( + f"{e_server_url}/v1/chat/completions", + json=encoder_request_data, + headers=headers, + ) + ) + try: + response = await task1 + if response.status != 200: + error_text = await response.text() + raise HTTPException( + status_code=response.status, + detail={"error": "Request failed", "message": error_text}, + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail={"error": "Internal server error", "message": str(e)}, + ) from e + + # import time + # time.sleep(10) + try: + async with decode_session.post( + f"{pd_server_url}/v1/chat/completions", json=request_data, headers=headers + ) as response: + response.raise_for_status() + async for chunk in response.content.iter_chunked(128): + if chunk: + yield chunk.decode("utf-8", errors="ignore") + except Exception as e: + logger.error("Error in streaming: %s", e) + raise + + +async def forward_non_streaming_request( + request_data: dict, + request_id: str, + e_server_url: str, + pd_server_url: str, +) -> dict: + headers = {"x-request-id": request_id} + # Skip request to encoder instance if we don't have mm input + if 
has_mm_input(request_data): + encoder_request_data = copy.deepcopy(request_data) + encoder_request_data["max_tokens"] = 1 + if "max_completion_tokens" in encoder_request_data: + encoder_request_data["max_completion_tokens"] = 1 + # Start request to encode server + task1 = asyncio.create_task( + encode_session.post( + f"{e_server_url}/v1/chat/completions", + json=encoder_request_data, + headers=headers, + ) + ) + + try: + response = await task1 + if response.status != 200: + error_text = await response.text() + raise HTTPException( + status_code=response.status, + detail={"error": "Request failed", "message": error_text}, + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail={"error": "Internal server error", "message": str(e)}, + ) from e + + try: + # Make request to decode server + async with decode_session.post( + f"{pd_server_url}/v1/chat/completions", json=request_data, headers=headers + ) as response2: + response2.raise_for_status() + result = await response2.json() + return result + except Exception as e: + logger.error("Error in non-streaming: %s", e) + raise + + +@app.post("/v1/chat/completions") +async def chat_completions(request: Request): + """Handle chat completion requests.""" + try: + e_instance = random.randint(0, len(app.state.e_urls) - 1) + pd_instance = random.randint(0, len(app.state.pd_urls) - 1) + e_server_url = app.state.e_urls[e_instance] + pd_server_url = app.state.pd_urls[pd_instance] + + request_data = await request.json() + request_id = request.headers.get("x-request-id") + if not request_id: + request_id = str(uuid.uuid4()) + is_streaming = request_data.get("stream", False) + if is_streaming: + return StreamingResponse( + forward_streaming_request( + request_data, request_id, e_server_url, pd_server_url + ), + media_type="text/event-stream", + ) + else: + result = await forward_non_streaming_request( + request_data, request_id, e_server_url, pd_server_url + ) + return JSONResponse(content=result) + except Exception as e: + logger.error("Error processing request: %s", e) + raise HTTPException(status_code=500, detail=str(e)) from e + + +@app.get("/v1/models") +async def list_models(): + try: + async with decode_session.get(f"{app.state.pd_urls[0]}/v1/models") as response: + response.raise_for_status() + return await response.json() + except Exception as e: + logger.error("Error fetching models: %s", e) + raise HTTPException(status_code=500, detail=str(e)) from e + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + try: + + async def check_encode(): + try: + for e_url in app.state.e_urls: + async with encode_session.get(f"{e_url}/health") as response: + response.raise_for_status() + return True + except Exception: + return False + + async def check_decode(): + try: + for pd_url in app.state.pd_urls: + async with encode_session.get(f"{pd_url}/health") as response: + response.raise_for_status() + return True + except Exception: + return False + + encode_healthy, decode_healthy = await asyncio.gather( + check_encode(), check_decode(), return_exceptions=True + ) + + health_status = { + "proxy": "healthy", + "encode_servers": "healthy" if encode_healthy is True else "unhealthy", + "prefill_decode_servers": "healthy" + if decode_healthy is True + else "unhealthy", + } + + if not (encode_healthy is True and decode_healthy is True): + return JSONResponse(content=health_status, status_code=503) + + return health_status + + except Exception as e: + logger.error("Health check error: %s", e) + return JSONResponse( + 
content={"proxy": "unhealthy", "error": str(e)}, status_code=503 + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="API Proxy for distributed vLLM servers" + ) + parser.add_argument("--host", type=str, default="0.0.0.0", help="Proxy host") + parser.add_argument("--port", type=int, default=8000, help="Proxy port") + + parser.add_argument( + "--encode-servers-urls", + type=str, + required=True, + help="URLs of the encode server in comma separated format" + '(e.g., "http://localhost:8001,http://localhost:8002")', + ) + + parser.add_argument( + "--prefill-decode-servers-urls", + type=str, + required=True, + help="URLs of the prefill/decode servers in comma separated format" + '(e.g., "http://localhost:8003,http://localhost:8004")', + ) + + args = parser.parse_args() + app.state.e_urls = args.encode_servers_urls.split(",") + app.state.pd_urls = args.prefill_decode_servers_urls.split(",") + + logger.info("Starting API proxy on %s:%s with 1 worker", args.host, args.port) + + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + access_log=False, + loop="uvloop", + ) diff --git a/examples/online_serving/epd/send_request.py b/examples/online_serving/epd/send_request.py new file mode 100644 index 000000000000..8e73d9d8a436 --- /dev/null +++ b/examples/online_serving/epd/send_request.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import base64 +import time + +from openai import AsyncOpenAI + + +async def async_query_openai(query, model_path, port): + aclient = AsyncOpenAI( + base_url=f"http://localhost:{str(port)}/v1", + api_key="EMPTY", + timeout=100000, + ) + completion = await aclient.chat.completions.create( + model=model_path, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": query, + }, + ], + temperature=0.0, + top_p=0.1, + max_tokens=512, + ) + return completion.choices[0].message.content + + +async def async_process_queries(queries, model_path, port): + results = await asyncio.gather( + *(async_query_openai(query, model_path, port) for query in queries) + ) + return results + + +async def main(args): + # single query + image_path = args.image_path + with open(image_path, "rb") as f: + encoded_image = base64.b64encode(f.read()) + encoded_image_text = encoded_image.decode("utf-8") + image_base64 = f"data:image;base64,{encoded_image_text}" + query = [ + { + "type": "image_url", + "image_url": {"url": image_base64}, + }, + {"type": "text", "text": "What is shown in the image.?"}, + ] + bs = args.batch_size + queries = [query for i in range(bs)] + + start_time = time.time() + results = await async_process_queries(queries, args.model_path, args.port) + end_time = time.time() + for result in results: + print(result) + print("-" * 50) + print(f"Total time: {end_time - start_time:.2f} seconds") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="test") + parser.add_argument("--model_path", type=str, default=None) + parser.add_argument("--image_path", type=str, default="./demo.jpeg") + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--port", type=int, default=10001) + args, _ = parser.parse_known_args() + + asyncio.run(main(args)) diff --git a/examples/online_serving/epd/stop.sh b/examples/online_serving/epd/stop.sh new file mode 100644 index 000000000000..839b21ab09ed --- /dev/null +++ 
b/examples/online_serving/epd/stop.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +# +# Reads pid.txt created by run_servers.sh and kills every process. +# + +set -euo pipefail + +PID_FILE="./pid.txt" +[[ -f "$PID_FILE" ]] || { + echo "No $PID_FILE found – nothing to stop." + exit 0 +} + +echo "Stopping processes listed in $PID_FILE …" + +while read -r pid; do + [[ -z "$pid" ]] && continue # skip blank lines + if kill -0 "$pid" 2>/dev/null; then + echo " → SIGTERM $pid" + kill "$pid" + # wait up to 5 s, escalate to SIGKILL if still alive + for _ in {1..5}; do + kill -0 "$pid" 2>/dev/null || break + sleep 1 + done + if kill -0 "$pid" 2>/dev/null; then + echo " → SIGKILL $pid" + kill -9 "$pid" || true + fi + else + echo " → PID $pid is already gone" + fi +done < "$PID_FILE" + +rm -f "$PID_FILE" +echo "Done." \ No newline at end of file diff --git a/vllm/config.py b/vllm/config.py index f038cdd64c67..b07ac540ceab 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3923,6 +3923,107 @@ class KVEventsConfig: """ +ECProducer = Literal["ec_producer"] +ECConsumer = Literal["ec_consumer"] +ECRole = Literal[ECProducer, ECConsumer] + + +@config +@dataclass +class ECTransferConfig: + """Configuration for distributed EC cache transfer.""" + + ec_connector: Optional[str] = None + """The KV connector for vLLM to transmit KV caches between vLLM instances. + """ + + engine_id: Optional[str] = None + """The engine id for KV transfers.""" + + ec_buffer_device: Optional[str] = "cuda" + """The device used by ec connector to buffer the EC cache. + Currently only support 'cuda'.""" + + ec_buffer_size: float = 1e9 + """The buffer size for TorchDistributedConnector. Measured in number of + bytes. Recommended value: 1e9 (about 1GB).""" + + ec_role: Optional[ECRole] = None + """Whether this vLLM instance produces, consumes KV cache, or both. Choices + are 'ec_producer', 'ec_consumer'.""" + + ec_rank: Optional[int] = None + """The rank of this vLLM instance in the KV cache transfer. Typical value: + 0 for encoder, 1 for pd instance. + Currently only 1P1D is supported.""" + + ec_parallel_size: int = 1 + """The number of parallel instances for KV cache transfer. For + PyNcclConnector, this should be 2.""" + + ec_ip: str = "127.0.0.1" + """The KV connector ip, used to build distributed connection.""" + + ec_port: int = 14579 + """The KV connector port, used to build distributed connection.""" + + ec_connector_extra_config: dict[str, Any] = field(default_factory=dict) + """any extra config that the connector may need.""" + + ec_connector_module_path: Optional[str] = None + """The Python module path to dynamically load the KV connector from. + Only supported in V1.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. 
+ factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + if self.engine_id is None: + self.engine_id = str(uuid.uuid4()) + + if self.ec_role is not None and self.ec_role not in get_args(ECRole): + raise ValueError(f"Unsupported ec_role: {self.ec_role}. " + f"Supported roles are {get_args(ECRole)}") + + if self.ec_connector is not None and self.ec_role is None: + raise ValueError("Please specify ec_disagg_role when ec_connector " + f"is set, supported roles are {get_args(ECRole)}") + + @property + def is_ec_transfer_instance(self) -> bool: + return self.ec_connector is not None and \ + self.ec_role in get_args(ECRole) + + @property + def is_ec_producer(self) -> bool: + return self.ec_connector is not None and \ + self.ec_role in get_args(ECProducer) + + @property + def is_ec_consumer(self) -> bool: + return self.ec_connector is not None and \ + self.ec_role in get_args(ECConsumer) + + def get_from_extra_config(self, key, default) -> Any: + return self.ec_connector_extra_config.get(key, default) + + class CompilationLevel: # constants for the levels of the compilation process NO_COMPILATION = 0 @@ -4379,6 +4480,8 @@ class VllmConfig: """The configurations for distributed KV cache transfer.""" kv_events_config: Optional[KVEventsConfig] = None """The configurations for event publishing.""" + ec_transfer_config: Optional[ECTransferConfig] = None + """The configurations for distributed encoder cache transfer.""" # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. @@ -4463,6 +4566,10 @@ def compute_hash(self) -> str: vllm_factors.append(self.kv_transfer_config.compute_hash()) else: vllm_factors.append("None") + if self.ec_transfer_config: + vllm_factors.append(self.ec_transfer_config.compute_hash()) + else: + vllm_factors.append("None") if self.additional_config: if isinstance(additional_config := self.additional_config, dict): additional_config_hash = hashlib.md5( diff --git a/vllm/distributed/ec_transfer/__init__.py b/vllm/distributed/ec_transfer/__init__.py new file mode 100644 index 000000000000..9fb784268183 --- /dev/null +++ b/vllm/distributed/ec_transfer/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.distributed.ec_transfer.ec_transfer_state import ( + ensure_ec_transfer_initialized, get_ec_transfer, has_ec_transfer) + +__all__ = [ + "get_ec_transfer", + "ensure_ec_transfer_initialized", + "has_ec_transfer", +] diff --git a/vllm/distributed/ec_transfer/ec_connector/__init__.py b/vllm/distributed/ec_transfer/ec_connector/__init__.py new file mode 100644 index 000000000000..3257011d1a0a --- /dev/null +++ b/vllm/distributed/ec_transfer/ec_connector/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.distributed.ec_transfer.ec_connector.base import (ECConnectorBase, + ECConnectorRole) + +__all__ = ["ECConnectorRole", "ECConnectorBase"] diff --git a/vllm/distributed/ec_transfer/ec_connector/base.py b/vllm/distributed/ec_transfer/ec_connector/base.py new file mode 100644 index 000000000000..ba9c03293444 --- /dev/null +++ b/vllm/distributed/ec_transfer/ec_connector/base.py @@ -0,0 +1,239 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors 
to the vLLM project +""" +ECConnectorBase Class for Distributed Encoder Cache & P2P Encoder cache +communication in V1 + +The class provides the following primitives: + Scheduler-side: runs in the scheduler, binds metadata, which + is used by the worker-side to load/save Encoder cache. + check_caches_exist() - Check whether Encoder cache of requests exist + update_state_after_alloc() - update ECConnector state after + allocate. This will decide to load the cache or not + request_finished() - called when a request is finished, free the + cache with the requests + + Worker-side: runs in each worker, loads/saves Encoder Cache to/from + the Connector based on the metadata. + start_load_ec() - starts loading all ECs (maybe async) + wait_for_save() - blocks until all saves are done + + get_finished() - called with ids of finished requests, returns + ids of requests that have completed async sending/recving. +""" + +import enum +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional, Union + +import torch + +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.outputs import ECConnectorOutput + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class ECConnectorRole(enum.Enum): + # Connector running in the scheduler process + SCHEDULER = 0 + + # Connector running in the worker process + WORKER = 1 + + +class ECConnectorMetadata(ABC): # noqa: B024 + """ + Abstract Metadata used to communicate between the + Scheduler ECConnector and Worker ECConnector. + """ + pass + + +class ECConnectorBase(ABC): + + def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole): + self._connector_metadata: Optional[ECConnectorMetadata] = None + self._vllm_config = vllm_config + self._role = role + self._is_producer = vllm_config.ec_transfer_config.is_ec_producer + + @property + def role(self) -> ECConnectorRole: + return self._role + + @property + def is_producer(self) -> bool: + return self._is_producer + + # ============================== + # Worker-side methods + # ============================== + + def bind_connector_metadata( + self, connector_metadata: ECConnectorMetadata) -> None: + """Set the connector metadata from the scheduler. + + This function should be called by the model runner every time + before the model execution. The metadata will be used for runtime + EC cache loading. + + Args: + connector_metadata (dict): the connector metadata. + """ + self._connector_metadata = connector_metadata + + def clear_connector_metadata(self) -> None: + """Clear the connector metadata. + + This function should be called by the model runner every time + after the model execution. + """ + self._connector_metadata = None + + def _get_connector_metadata(self) -> ECConnectorMetadata: + """Get the connector metadata. + + This function should only be called inside the connector. + + Returns: + ConnectorMetadata: the connector metadata. + """ + + # Should only be called while set to valid metadata. + assert self._connector_metadata is not None + return self._connector_metadata + + def register_caches( + self, + ec_caches: dict[str, torch.Tensor], + ): + """ + Initialize with the EC caches. + Args: + ec_caches: dictionary of encoder cache + """ + # TODO: Implement this later for P2P feature + return + + @abstractmethod + def start_load_caches(self, **kwargs) -> None: + """ + Start loading the cache from the connector to vLLM's encoder cache. 
+ This is called before _gather_mm_embeddings for EC Connector + For EC the encoder_cache and mm_hash is store in kwargs + + Args: + **kwargs: additional arguments for the load operation + + """ + pass + + @abstractmethod + def save_caches(self, **kwargs) -> None: + """ + Save caches into connector + For EC the encoder_cache and mm_hash is store in kwargs + """ + pass + + @abstractmethod + def wait_for_save(self): + """ + Block until all the save operations is done. + """ + pass + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens on the worker. + The scheduler process (via the Executors) will use this output + to track which workers are done. + + Returns: + ids of requests that have finished asynchronous transfer + (requests that previously returned True from request_finished()), + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + return None, None + + # ============================== + # Scheduler-side methods + # ============================== + + @abstractmethod + def check_caches_exist( + self, + request: "Request", + index: Optional[int] = None, + ) -> Union[bool, list[bool]]: + """ + Check if encoder cache exists for each mm data of requests. + + Args: + request (Request): the request object. + index (Optional[int]): the index of the request in the batch. + + Returns: + Union[bool, list[bool]]: True if cache exists for the specific + index, or a list of booleans indicating the existence of caches + for all indexes. + """ + pass + + @abstractmethod + def update_state_after_alloc(self, request: "Request", index: int): + """ + Update ECConnector state to decide allocate cache for requests + + Args: + request (Request): the request object. + """ + pass + + @abstractmethod + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> ECConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + pass + + def update_connector_output(self, connector_output: ECConnectorOutput): + """ + Update ECConnector state from worker-side connectors output. + + Args: + connector_output (ECConnectorOutput): the worker-side + connectors output. + """ + return + + def request_finished( + self, request: "Request") -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its freed the local + encoder cached. + + Returns: + True if the request is being saved/sent asynchronously and cached + should not be freed until the request_id is returned from + get_finished(). 
+ """ + return False, None diff --git a/vllm/distributed/ec_transfer/ec_connector/factory.py b/vllm/distributed/ec_transfer/ec_connector/factory.py new file mode 100644 index 000000000000..2cd7239a094b --- /dev/null +++ b/vllm/distributed/ec_transfer/ec_connector/factory.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import importlib +from typing import TYPE_CHECKING, Callable + +# yapf: disable +import vllm.envs as envs +from vllm.distributed.ec_transfer.ec_connector.base import (ECConnectorBase, + ECConnectorRole) +from vllm.logger import init_logger + +# yapf: enable + +if TYPE_CHECKING: + from vllm.config import ECTransferConfig, VllmConfig + +logger = init_logger(__name__) + + +class ECConnectorFactory: + _registry: dict[str, Callable[[], type[ECConnectorBase]]] = {} + + @classmethod + def register_connector(cls, name: str, module_path: str, + class_name: str) -> None: + """Register a connector with a lazy-loading module and class name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> type[ECConnectorBase]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_connector( + cls, + config: "VllmConfig", + role: ECConnectorRole, + ) -> ECConnectorBase: + if not envs.VLLM_USE_V1: + raise ValueError("Attempting to initialize a V1 Connector, " + f"but found {envs.VLLM_USE_V1=}") + + ec_transfer_config = config.ec_transfer_config + connector_cls = cls.get_connector_class(ec_transfer_config) + logger.info("Creating v1 connector with name: %s and engine_id: %s", + connector_cls.__name__, ec_transfer_config.engine_id) + # NOTE(Kuntai): v1 connector is explicitly separated into two roles. + # Scheduler connector: + # - Co-locate with scheduler process + # - Should only be used inside the Scheduler class + # Worker connector: + # - Co-locate with worker process + # - Should only be used inside the forward context & attention layer + # We build separately to enforce strict separation + return connector_cls(config, role) + + @classmethod + def get_connector_class( + cls, + ec_transfer_config: "ECTransferConfig") -> type[ECConnectorBase]: + """Get the connector class by name.""" + connector_name = ec_transfer_config.ec_connector + if connector_name in cls._registry: + connector_cls = cls._registry[connector_name]() + else: + connector_module_path = ec_transfer_config.ec_connector_module_path + if connector_module_path is None: + raise ValueError( + f"Unsupported connector type: {connector_name}") + connector_module = importlib.import_module(connector_module_path) + connector_cls = getattr(connector_module, connector_name) + return connector_cls + + +# Register various connectors here. +# The registration should not be done in each individual file, as we want to +# only load the files corresponding to the current connector. 
+ +ECConnectorFactory.register_connector( + "ECSharedStorageConnector", + "vllm.distributed.ec_transfer.ec_connector.shared_storage_connector", + "ECSharedStorageConnector") diff --git a/vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py b/vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py new file mode 100644 index 000000000000..2e94923654b3 --- /dev/null +++ b/vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union + +import safetensors + +from vllm.config import VllmConfig +from vllm.distributed.ec_transfer.ec_connector.base import ( + ECConnectorBase, ECConnectorMetadata, ECConnectorRole) +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class MMMeta: + # mm_hash: str + # num_token: int + request_id: str = "" + input_ids: list[int] = None + + # @staticmethod + # def make_meta(mm_hash, num_token) -> "MMMeta": + # return MMMeta(mm_hash=mm_hash, num_token=num_token) + + @staticmethod + def make_mm_meta(request_id: str, input_ids: list[int]) -> "MMMeta": + return MMMeta(request_id=request_id, input_ids=input_ids) + + +@dataclass +class ECSharedStorageConnectorMetadata(ECConnectorMetadata): + mm_datas: list[MMMeta] + + def __init__(self): + self.mm_datas = [] + + def add_mm_data(self, mm_data: MMMeta): + self.mm_datas.append(mm_data) + + def add_mm_metadata(self, request_id: str, input_ids: list[int]): + self.mm_datas.append(MMMeta.make_mm_meta(request_id, input_ids)) + + +class ECSharedStorageConnector(ECConnectorBase): + # NOTE: This is Simple debug implementation of the EC connector. + # It save / load the EC cache to / from the disk. 
+ + def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + # req_id -> index -> MMMeta + self._mm_datas_need_loads: dict[str, int] = {} + self._mm_datas: dict[str, list[int]] = {} + transfer_config = vllm_config.ec_transfer_config + self._storage_path = transfer_config.get_from_extra_config( + "shared_storage_path", "/tmp") + logger.debug(transfer_config) + logger.debug("Shared storage path is %s", self._storage_path) + + def start_load_caches(self, **kwargs) -> None: + """Start loading the EC cache from the connector buffer to worker + encoder_cache + + Args: + **kwargs: additional arguments for the load operation + """ + + # Get the metadata + metadata: ECConnectorMetadata = self._get_connector_metadata() + assert isinstance(metadata, ECSharedStorageConnectorMetadata) + encoder_cache = kwargs.get("encoder_cache") # returns None if missing + assert encoder_cache is not None + if metadata is None: + logger.warning( + "In connector.start_load_caches, but the connector metadata " + "is None") + return + + for mm_data in metadata.mm_datas: + for input_id in mm_data.input_ids: + if input_id in encoder_cache.get(mm_data.request_id, {}): + continue + filename = self._generate_filename_debug( + f"{mm_data.request_id}_{input_id}") + if not os.path.exists(filename): + logger.warning("Encoder cache file %s does not exist", + filename) + continue + ec_cache = safetensors.torch.load_file( + filename)["ec_cache"].npu() + if mm_data.request_id not in encoder_cache: + encoder_cache[mm_data.request_id] = {} + encoder_cache[mm_data.request_id][input_id] = ec_cache + logger.debug( + "Success load encoder cache for request_id %s, input_id %d", + mm_data.request_id, input_id) + + def save_caches(self, **kwargs) -> None: + """Start saving the EC cache for each mm_datas from encoder cache + + Args: + **kwargs: additional arguments for the save operation. + """ + # Return if it is PD Instance + if not self.is_producer: + return + encoder_cache = kwargs.get("encoder_cache") + mm_hash = kwargs.get("mm_hash") + assert encoder_cache is not None + if mm_hash: + filename = self._generate_filename_debug(mm_hash) + ec_cache = encoder_cache[mm_hash] + else: + request_id = kwargs.get("request_id") + input_id = kwargs.get("input_id") + filename = self._generate_filename_debug( + f"{request_id}_{input_id}") + ec_cache = encoder_cache[request_id][input_id][0:3, :] + tensors = {"ec_cache": ec_cache.detach().cpu()} + safetensors.torch.save_file(tensors, filename) + logger.debug( + "Save cache successful for mm_hash %s, request_id %s, input_id %s", + mm_hash, request_id, input_id) + + def wait_for_save(self): + return + + def check_caches_exist( + self, + request: "Request", + index: Optional[int] = None, + ) -> Union[bool, list[bool]]: + """ + Check if cache exist externally for each mm_data of request + + Args: + request (Request): the request object. + index (Optional[int]): the index of the request in the batch. 
+ + Returns: + List of bool indicate that ith mm_data exist in cache or not + """ + result = [] + request_id = request.request_id + if index is not None: + return self._found_match_for_mm_data(f"{request_id}_{index}") + + for input_id in range(len(request.mm_positions)): + if self._found_match_for_mm_data(f"{request_id}_{input_id}"): + result.append(True) + else: + result.append(False) + + # for mm_hash in request.mm_hashes: + # result.append(self._found_match_for_mm_data(mm_hash)) + return result + + def update_state_after_alloc( + self, + request: "Request", + index: int, + ) -> None: + """ + Update ECConnector state after encoder cache allocation. + """ + # mm_hash = request.mm_hashes[index] + # num_encoder_token = request.get_num_encoder_tokens(index) + # # Insert mm_hash only if this block has not been recorded yet. + # self._mm_datas_need_loads[mm_hash] = num_encoder_token + self._mm_datas.setdefault(request.request_id, []).append(index) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> ECConnectorMetadata: + """Build the connector metadata for this step. + + This function should NOT modify any fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + This only build for load mm_data only + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + meta = ECSharedStorageConnectorMetadata() + for mm_data in self._mm_datas: + meta.add_mm_metadata(mm_data, self._mm_datas[mm_data]) + self._mm_datas.clear() + # for mm_hash, num_encoder_token in self._mm_datas_need_loads.items(): + # meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token)) + # self._mm_datas_need_loads.clear() + return meta + + # ============================== + # Helper functions + # ============================== + + def _found_match_for_mm_data(self, mm_hash) -> bool: + """Check if the cache is hit for the request. + """ + filename = self._generate_filename_debug(mm_hash) + return os.path.exists(filename) + + def _generate_foldername_debug( + self, + mm_hash: str, + create_folder: bool = True, + ) -> str: + """ + Return the folder in which the cache for this mm_hash lives. + If `create_folder` is True (default) the directory is created + recursively the first time it is needed. + """ + # foldername = os.path.join(self._storage_path, mm_hash) + foldername = self._storage_path + if create_folder: + os.makedirs(foldername, exist_ok=True) + return foldername + + def _generate_filename_debug(self, mm_hash: str) -> str: + """ + Return the full path of the safetensors file for this mm_hash. + Ensures the parent directory exists because + `_generate_foldername_debug` is called with its default + (`create_folder=True`). 
+ """ + foldername = self._generate_foldername_debug(mm_hash, False) + return os.path.join(foldername, f"{mm_hash}_encoder_cache.safetensors") diff --git a/vllm/distributed/ec_transfer/ec_transfer_state.py b/vllm/distributed/ec_transfer/ec_transfer_state.py new file mode 100644 index 000000000000..91a8c80d1e96 --- /dev/null +++ b/vllm/distributed/ec_transfer/ec_transfer_state.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Optional + +from vllm import envs +from vllm.distributed.ec_transfer.ec_connector.base import (ECConnectorBase, + ECConnectorRole) +from vllm.distributed.ec_transfer.ec_connector.factory import ( + ECConnectorFactory) + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +_EC_CONNECTOR_AGENT: Optional[ECConnectorBase] = None + + +def get_ec_transfer() -> ECConnectorBase: + assert _EC_CONNECTOR_AGENT is not None, ( + "disaggregated EC cache is not initialized") + return _EC_CONNECTOR_AGENT + + +def has_ec_transfer() -> bool: + return _EC_CONNECTOR_AGENT is not None + + +def ensure_ec_transfer_initialized(vllm_config: "VllmConfig") -> None: + """ + Initialize EC cache connector. + """ + + global _EC_CONNECTOR_AGENT + + if vllm_config.ec_transfer_config is None: + return + + if (vllm_config.ec_transfer_config.is_ec_transfer_instance + and _EC_CONNECTOR_AGENT is None): + if envs.VLLM_USE_V1: + _EC_CONNECTOR_AGENT = ECConnectorFactory.create_connector( + config=vllm_config, role=ECConnectorRole.WORKER) + else: + raise ValueError("V0 is no longer supported") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index aec75f82631a..1e28f5825be3 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -24,15 +24,15 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, ConfigFormat, ConfigType, DecodingConfig, DetailedTraceModules, Device, DeviceConfig, - DistributedExecutorBackend, GuidedDecodingBackend, - GuidedDecodingBackendV1, HfOverrides, KVEventsConfig, - KVTransferConfig, LoadConfig, LoadFormat, - LogprobsMode, LoRAConfig, ModelConfig, ModelDType, - ModelImpl, MultiModalConfig, ObservabilityConfig, - ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, - SchedulerConfig, SchedulerPolicy, SpeculativeConfig, - TaskOption, TokenizerMode, VllmConfig, get_attr_docs, - get_field) + DistributedExecutorBackend, ECTransferConfig, + GuidedDecodingBackend, GuidedDecodingBackendV1, + HfOverrides, KVEventsConfig, KVTransferConfig, + LoadConfig, LoadFormat, LogprobsMode, LoRAConfig, + ModelConfig, ModelDType, ModelImpl, MultiModalConfig, + ObservabilityConfig, ParallelConfig, PoolerConfig, + PrefixCachingHashAlgo, SchedulerConfig, + SchedulerPolicy, SpeculativeConfig, TaskOption, + TokenizerMode, VllmConfig, get_attr_docs, get_field) from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.plugins import load_general_plugins @@ -414,6 +414,8 @@ class EngineArgs: kv_transfer_config: Optional[KVTransferConfig] = None kv_events_config: Optional[KVEventsConfig] = None + ec_transfer_config: Optional[ECTransferConfig] = None + generation_config: str = ModelConfig.generation_config enable_sleep_mode: bool = ModelConfig.enable_sleep_mode override_generation_config: dict[str, Any] = \ @@ -831,6 +833,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **vllm_kwargs["kv_transfer_config"]) vllm_group.add_argument('--kv-events-config', 
**vllm_kwargs["kv_events_config"]) + vllm_group.add_argument("--ec-transfer-config", + **vllm_kwargs["ec_transfer_config"]) vllm_group.add_argument("--compilation-config", "-O", **vllm_kwargs["compilation_config"]) vllm_group.add_argument("--additional-config", @@ -1288,6 +1292,7 @@ def create_engine_config( compilation_config=self.compilation_config, kv_transfer_config=self.kv_transfer_config, kv_events_config=self.kv_events_config, + ec_transfer_config=self.ec_transfer_config, additional_config=self.additional_config, ) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index dcab00822870..97bcae0bcbfc 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -30,7 +30,8 @@ def get_num_image_tokens( image_width: int, image_height: int, ) -> int: - return self.get_patch_grid_length()**2 + 1 + return 2 + # return self.get_patch_grid_length()**2 + 1 def get_image_size(self) -> int: return self.vision_config.image_size diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 48ec611df12d..1d94cc987de4 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -343,14 +343,16 @@ def __init__(self, ) else: self.embed_tokens = PPMissingLayer() - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: layer_type(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), - prefix=f"{prefix}.layers", - ) + + if not vllm_config.ec_transfer_config.is_ec_producer: + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: @@ -416,6 +418,10 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: + if "layers." in name: + layer_index = extract_layer_index(name) + if layer_index >= 37: + continue if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index 979d789b330c..b41daea982fe 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -416,6 +416,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: multimodal_projector_bias=projector_bias, quant_config=quant_config, prefix=maybe_prefix(prefix, "multi_modal_projector")) + self.vllm_config = vllm_config self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config. 
@@ -644,5 +645,5 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) + loader = AutoWeightsLoader(self, vllm_config=self.vllm_config) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 62deb68035b9..7ba1aab442c0 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -109,6 +109,7 @@ def __init__( skip_prefixes: Optional[list[str]] = None, skip_substrs: Optional[list[str]] = None, ignore_unexpected_prefixes: Optional[list[str]] = None, + vllm_config=None, ) -> None: super().__init__() @@ -118,6 +119,7 @@ def __init__( self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or [] # update default skip_substrs self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS + self.vllm_config = vllm_config def _groupby_prefix( self, @@ -285,8 +287,14 @@ def load_weights( if mapper is not None: weights = mapper.apply(weights) # filter out weights with first-prefix/substr to skip in name - weights = ((name, weight) for name, weight in weights - if not self._can_skip(name)) + if (self.vllm_config is not None + and self.vllm_config.ec_transfer_config.is_ec_producer): + weights = ((name, weight) for name, weight in weights + if not self._can_skip(name) + and not name.startswith("language_model.model.layers.")) + else: + weights = ((name, weight) for name, weight in weights + if not self._can_skip(name)) autoloaded_weights = set(self._load_module("", self.module, weights)) return autoloaded_weights diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index d34f39327805..2d16f7990ed3 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -10,6 +10,8 @@ import numpy as np import numpy.typing as npt + from vllm.distributed.ec_transfer.ec_connector.base import ( + ECConnectorMetadata) from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorMetadata) from vllm.lora.request import LoRARequest @@ -155,3 +157,6 @@ class SchedulerOutput: # KV Cache Connector metadata. 
kv_connector_metadata: Optional[KVConnectorMetadata] = None + + # Encoder Cache Connector metadata + ec_connector_metadata: Optional[ECConnectorMetadata] = None \ No newline at end of file diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 446f98034cb8..cd5275632ca8 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -10,6 +10,9 @@ from typing import Any, Optional, Union from vllm.config import VllmConfig +from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorRole +from vllm.distributed.ec_transfer.ec_connector.factory import ( + ECConnectorFactory) from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) @@ -91,6 +94,16 @@ def __init__( self.parallel_config.data_parallel_rank, ) + self.ec_connector = None + self.ec_max_num_scheduled_tokens = 0 + if self.vllm_config.ec_transfer_config is not None: + self.ec_connector = ECConnectorFactory.create_connector( + config=self.vllm_config, role=ECConnectorRole.SCHEDULER) + transfer_config = self.vllm_config.ec_transfer_config + self.ec_max_num_scheduled_tokens = int( + transfer_config.get_from_extra_config( + "ec_max_num_scheduled_tokens", 0)) + num_gpu_blocks = self.cache_config.num_gpu_blocks assert num_gpu_blocks is not None and num_gpu_blocks > 0 @@ -189,7 +202,8 @@ def schedule(self) -> SchedulerOutput: req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {} num_scheduled_tokens: dict[str, int] = {} - token_budget = self.max_num_scheduled_tokens + token_budget = max(self.max_num_scheduled_tokens, \ + self.ec_max_num_scheduled_tokens) # Encoder-related. scheduled_encoder_inputs: dict[str, list[int]] = {} encoder_budget = self.max_num_encoder_input_tokens @@ -224,7 +238,8 @@ def schedule(self) -> SchedulerOutput: new_encoder_budget = encoder_budget if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget) = self._try_schedule_encoder_inputs( + new_encoder_budget, external_load_encoder_input + ) = self._try_schedule_encoder_inputs( request, request.num_computed_tokens, num_new_tokens, encoder_budget) @@ -315,6 +330,10 @@ def schedule(self) -> SchedulerOutput: for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget + if (self.ec_connector is not None and external_load_encoder_input): + for i in external_load_encoder_input: + self.encoder_cache_manager.allocate(request, i) + self.ec_connector.update_state_after_alloc(request, i) # Record the LoRAs in scheduled_running_reqs scheduled_loras: set[int] = set() @@ -429,7 +448,7 @@ def schedule(self) -> SchedulerOutput: # Schedule encoder inputs. 
if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget + new_encoder_budget, external_load_encoder_input ) = self._try_schedule_encoder_inputs( request, num_computed_tokens, num_new_tokens, encoder_budget) @@ -505,6 +524,12 @@ def schedule(self) -> SchedulerOutput: for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget + # Allocate for external load encoder cache + if (self.ec_connector is not None + and external_load_encoder_input): + for i in external_load_encoder_input: + self.encoder_cache_manager.allocate(request, i) + self.ec_connector.update_state_after_alloc(request, i) # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: @@ -512,7 +537,10 @@ def schedule(self) -> SchedulerOutput: # Check if the scheduling constraints are satisfied. total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) - assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert total_num_scheduled_tokens <= max( + self.max_num_scheduled_tokens, + self.ec_max_num_scheduled_tokens, + ) assert token_budget >= 0 assert len(self.running) <= self.max_num_running_reqs # Since some requests in the RUNNING queue may not be scheduled in @@ -580,6 +608,10 @@ def schedule(self) -> SchedulerOutput: batch = KVEventBatch(ts=time.time(), events=events) self.kv_event_publisher.publish(batch) + if self.ec_connector is not None: + meta = self.ec_connector.build_connector_meta(scheduler_output) + scheduler_output.ec_connector_metadata = meta + self._update_after_schedule(scheduler_output) return scheduler_output @@ -668,7 +700,7 @@ def _try_schedule_encoder_inputs( num_computed_tokens: int, num_new_tokens: int, encoder_budget: int, - ) -> tuple[list[int], int, int]: + ) -> tuple[list[int], int, int, list[int]]: """ Determine which encoder inputs need to be scheduled in the current step, and update `num_new_tokens` and encoder token budget accordingly. @@ -689,11 +721,13 @@ def _try_schedule_encoder_inputs( blocks and externally cached blocks (via KVConnector). """ if num_new_tokens == 0 or not request.has_encoder_inputs: - return [], num_new_tokens, encoder_budget + return [], num_new_tokens, encoder_budget, [] encoder_inputs_to_schedule: list[int] = [] mm_positions = request.mm_positions assert mm_positions is not None assert len(mm_positions) > 0 + external_load_encoder_input = [] + for i, pos_info in enumerate(mm_positions): start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -741,9 +775,15 @@ def _try_schedule_encoder_inputs( num_new_tokens = 0 break + if (self.ec_connector is not None + and self.ec_connector.check_caches_exist(request, i)): + external_load_encoder_input.append(i) + continue + encoder_budget -= num_encoder_tokens encoder_inputs_to_schedule.append(i) - return encoder_inputs_to_schedule, num_new_tokens, encoder_budget + return (encoder_inputs_to_schedule, num_new_tokens, encoder_budget, + external_load_encoder_input) def update_from_output( self, @@ -767,7 +807,7 @@ def update_from_output( stopped_running_reqs: set[Request] = set() stopped_preempted_reqs: set[Request] = set() for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): - assert num_tokens_scheduled > 0 + # assert num_tokens_scheduled > 0 request = self.requests.get(req_id) if request is None: # The request is already finished. 
This can happen if the diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index f78623f571b2..22e416ee93d4 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -2,10 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import NamedTuple, Optional +from typing import TYPE_CHECKING, NamedTuple, Optional import torch +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + class LogprobsLists(NamedTuple): @@ -71,6 +74,13 @@ class SamplerOutput: logprobs_tensors: Optional[LogprobsTensors] +@dataclass +class ECConnectorOutput: + # [mm_hash] + finished_sending: Optional[set[str]] = None + finished_recving: Optional[set[str]] = None + + # ModelRunnerOutput is serialized and sent to the scheduler process. # This is expensive for torch.Tensor so prefer to use list instead. @dataclass @@ -108,6 +118,8 @@ class ModelRunnerOutput: finished_sending: Optional[set[str]] = None finished_recving: Optional[set[str]] = None + ec_connector_output: Optional[ECConnectorOutput] = None + # req_id -> num_nans_in_logits num_nans_in_logits: Optional[dict[str, int]] = None @@ -122,3 +134,40 @@ class ModelRunnerOutput: finished_sending=None, finished_recving=None, num_nans_in_logits=None) + + +def make_empty_encoder_model_runner_output( + scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: + """ + Create a ModelRunnerOutput stub that contains the correct + per-request bookkeeping but no generated data yet. + """ + if not scheduler_output.num_scheduled_tokens: + return EMPTY_MODEL_RUNNER_OUTPUT + + # Convert to list so we get a deterministic, indexable sequence + req_ids: list[str] = list(scheduler_output.num_scheduled_tokens.keys()) + + # Give every request its own contiguous index + req_id_to_index: dict[str, int] = { + rid: idx + for idx, rid in enumerate(req_ids) + } + + # No tokens generated yet ⇒ one empty list per request + sampled_token_ids: list[list[int]] = [[0] for _ in req_ids] + + # Pooler outputs are not available yet ⇒ use None placeholders + pooler_output: list[Optional[torch.Tensor]] = [None for _ in req_ids] + + return ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_id_to_index, + sampled_token_ids=sampled_token_ids, + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=pooler_output, + ec_connector_output=None, + num_nans_in_logits=None, + ) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 85f5dcb92eb4..9d9e98c5dade 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -168,9 +168,10 @@ def get_finished_reason(self) -> Union[FinishReason, None]: return RequestStatus.get_finished_reason(self.status) def get_num_encoder_tokens(self, input_id: int) -> int: - assert input_id < len(self.mm_positions) - num_tokens = self.mm_positions[input_id].length - return num_tokens + # assert input_id < len(self.mm_positions) + # num_tokens = self.mm_positions[input_id].length + # return num_tokens + return 1 @property def use_structured_output(self) -> bool: diff --git a/vllm/v1/worker/ec_connector_model_runner_mixin.py b/vllm/v1/worker/ec_connector_model_runner_mixin.py new file mode 100644 index 000000000000..49d72507e4ca --- /dev/null +++ b/vllm/v1/worker/ec_connector_model_runner_mixin.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Define EC connector functionality mixin for model runners. 
+""" +from contextlib import AbstractContextManager, contextmanager, nullcontext +from typing import Generator # noqa: UP035 +from typing import TYPE_CHECKING, Optional + +import torch + +from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer +from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorBase +from vllm.logger import init_logger +from vllm.v1.outputs import ECConnectorOutput + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + +logger = init_logger(__name__) + + +# Defined as a EC connector functionality mixin for ModelRunner (GPU, TPU) +class ECConnectorModelRunnerMixin: + + @staticmethod + def maybe_setup_ec_connector(scheduler_output: "SchedulerOutput"): + # Set up EC connector for load cache + if has_ec_transfer(): + ec_connector = get_ec_transfer() + assert isinstance(ec_connector, ECConnectorBase) + assert scheduler_output.ec_connector_metadata is not None + ec_connector.bind_connector_metadata( + scheduler_output.ec_connector_metadata) + ec_connector.start_load_caches() + + @staticmethod + def maybe_save_ec_to_connector( + encoder_cache: dict[str, torch.Tensor], + mm_hash: Optional[str] = None, + request_id: Optional[str] = None, + input_id: Optional[int] = None, + ): + if not has_ec_transfer(): + logger.debug("Not have ec transfer please check") + return + connector = get_ec_transfer() + connector.save_caches(encoder_cache=encoder_cache, + mm_hash=mm_hash, + request_id=request_id, + input_id=input_id) + + @staticmethod + def maybe_wait_for_ec_save() -> None: + if has_ec_transfer(): + get_ec_transfer().wait_for_save() + + @staticmethod + def get_finished_ec_transfers( + scheduler_output: "SchedulerOutput", + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + if has_ec_transfer(): + return get_ec_transfer().get_finished( + scheduler_output.finished_req_ids) + return None, None + + @staticmethod + def maybe_get_ec_connector_output( + scheduler_output: "SchedulerOutput", + **kwargs, + ) -> AbstractContextManager[Optional[ECConnectorOutput]]: + return ECConnectorModelRunnerMixin._get_ec_connector_output( + scheduler_output, ** + kwargs) if has_ec_transfer() else nullcontext() + + # This context manager must be used within an active forward context. + # It encapsulates the entire EC conector lifecycle within execute_model + @staticmethod + @contextmanager + def _get_ec_connector_output( + scheduler_output: "SchedulerOutput", + wait_for_save: bool = True, + **kwargs, + ) -> Generator[ECConnectorOutput, None, None]: + output = ECConnectorOutput() + + ec_connector = get_ec_transfer() + assert isinstance(ec_connector, ECConnectorBase) + assert scheduler_output.ec_connector_metadata is not None + ec_connector.bind_connector_metadata( + scheduler_output.ec_connector_metadata) + + ec_connector.start_load_caches(**kwargs) + try: + yield output + finally: + if wait_for_save: + ec_connector.wait_for_save() + + output.finished_sending, output.finished_recving = ( + ec_connector.get_finished(scheduler_output.finished_req_ids)) + + ec_connector.clear_connector_metadata()