From 9352f1cbd5e5c84de66b22c184d0ffcfdede4af1 Mon Sep 17 00:00:00 2001 From: herotai214 Date: Wed, 10 Dec 2025 07:14:32 +0000 Subject: [PATCH 1/2] [WIP][Feature] Nixl ECConnector Signed-off-by: herotai214 --- .../disagg_1e1p1d_example.sh | 22 +- .../disagg_1e1pd_example.sh | 104 +- .../disagg_1e1pd_example_sharedstorage.sh | 190 +++ .../disaggregated_encoder/disagg_epd_proxy.py | 64 +- .../integration/run_epd_correctness_test.sh | 37 +- .../ec_transfer/ec_connector/factory.py | 7 + .../ec_connector/nixl_connector.py | 1188 +++++++++++++++++ vllm/entrypoints/openai/protocol.py | 23 + vllm/entrypoints/openai/serving_chat.py | 1 + vllm/entrypoints/openai/serving_completion.py | 4 + vllm/entrypoints/openai/serving_tokens.py | 1 + vllm/envs.py | 4 + vllm/outputs.py | 4 + vllm/v1/core/sched/scheduler.py | 35 +- vllm/v1/engine/__init__.py | 1 + vllm/v1/engine/output_processor.py | 7 +- vllm/v1/outputs.py | 1 + vllm/v1/request.py | 12 + .../worker/ec_connector_model_runner_mixin.py | 6 + vllm/v1/worker/gpu_model_runner.py | 118 +- 20 files changed, 1733 insertions(+), 96 deletions(-) create mode 100644 examples/online_serving/disaggregated_encoder/disagg_1e1pd_example_sharedstorage.sh create mode 100644 vllm/distributed/ec_transfer/ec_connector/nixl_connector.py diff --git a/examples/online_serving/disaggregated_encoder/disagg_1e1p1d_example.sh b/examples/online_serving/disaggregated_encoder/disagg_1e1p1d_example.sh index 57489df64f51..882d98a53f1b 100644 --- a/examples/online_serving/disaggregated_encoder/disagg_1e1p1d_example.sh +++ b/examples/online_serving/disaggregated_encoder/disagg_1e1p1d_example.sh @@ -92,7 +92,10 @@ mkdir -p $EC_SHARED_STORAGE_PATH ############################################################################### # Encoder worker ############################################################################### -CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ +CUDA_VISIBLE_DEVICES="$GPU_E" \ +VLLM_DEBUG_DUMP_PATH=$LOG_PATH \ +VLLM_NIXL_EC_SIDE_CHANNEL_PORT=5569 \ +vllm serve "$MODEL" \ --gpu-memory-utilization 0.01 \ --port "$ENCODE_PORT" \ --enforce-eager \ @@ -102,11 +105,8 @@ CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", - "ec_role": "ec_producer", - "ec_connector_extra_config": { - "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" - } + "ec_connector": "NixlECConnector", + "ec_role": "ec_producer" }' \ >"${ENC_LOG}" 2>&1 & @@ -116,8 +116,10 @@ PIDS+=($!) # Prefill worker ############################################################################### CUDA_VISIBLE_DEVICES="$GPU_P" \ +VLLM_DEBUG_DUMP_PATH=$LOG_PATH \ UCX_NET_DEVICES=all \ VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \ +VLLM_NIXL_EC_SIDE_CHANNEL_PORT=5579 \ vllm serve "$MODEL" \ --gpu-memory-utilization 0.7 \ --port "$PREFILL_PORT" \ @@ -126,11 +128,8 @@ vllm serve "$MODEL" \ --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", - "ec_role": "ec_consumer", - "ec_connector_extra_config": { - "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" - } + "ec_connector": "NixlECConnector", + "ec_role": "ec_consumer" }' \ --kv-transfer-config '{ "kv_connector": "NixlConnector", @@ -144,6 +143,7 @@ PIDS+=($!) 
# Decode worker ############################################################################### CUDA_VISIBLE_DEVICES="$GPU_D" \ +VLLM_DEBUG_DUMP_PATH=$LOG_PATH \ UCX_NET_DEVICES=all \ VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \ vllm serve "$MODEL" \ diff --git a/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example.sh b/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example.sh index 6073e0580b11..6d28cca25ab2 100644 --- a/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example.sh +++ b/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example.sh @@ -14,7 +14,7 @@ ENCODE_PORT="${ENCODE_PORT:-19534}" PREFILL_DECODE_PORT="${PREFILL_DECODE_PORT:-19535}" PROXY_PORT="${PROXY_PORT:-10001}" -GPU_E="${GPU_E:-0}" +GPU_E="${GPU_E:-2}" GPU_PD="${GPU_PD:-1}" EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}" @@ -86,7 +86,10 @@ mkdir -p $EC_SHARED_STORAGE_PATH ############################################################################### # Encoder worker ############################################################################### -CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ +CUDA_VISIBLE_DEVICES="$GPU_E" \ +VLLM_DEBUG_DUMP_PATH=$LOG_PATH/dump \ +VLLM_NIXL_EC_SIDE_CHANNEL_PORT=5569 \ +vllm serve "$MODEL" \ --gpu-memory-utilization 0.01 \ --port "$ENCODE_PORT" \ --enforce-eager \ @@ -96,11 +99,8 @@ CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", - "ec_role": "ec_producer", - "ec_connector_extra_config": { - "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" - } + "ec_connector": "NixlECConnector", + "ec_role": "ec_producer" }' \ >"${ENC_LOG}" 2>&1 & @@ -109,7 +109,10 @@ PIDS+=($!) ############################################################################### # Prefill+Decode worker ############################################################################### -CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \ +VLLM_NIXL_EC_SIDE_CHANNEL_PORT=5579 \ +VLLM_DEBUG_DUMP_PATH=$LOG_PATH/dump \ +CUDA_VISIBLE_DEVICES="$GPU_PD" \ +vllm serve "$MODEL" \ --gpu-memory-utilization 0.7 \ --port "$PREFILL_DECODE_PORT" \ --enforce-eager \ @@ -117,11 +120,8 @@ CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \ --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", - "ec_role": "ec_consumer", - "ec_connector_extra_config": { - "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" - } + "ec_connector": "NixlECConnector", + "ec_role": "ec_consumer" }' \ >"${PD_LOG}" 2>&1 & @@ -147,40 +147,44 @@ PIDS+=($!) wait_for_server $PROXY_PORT echo "All services are up!" -############################################################################### -# Benchmark -############################################################################### -echo "Running benchmark (stream)..." -vllm bench serve \ - --model $MODEL \ - --backend openai-chat \ - --endpoint /v1/chat/completions \ - --dataset-name hf \ - --dataset-path lmarena-ai/VisionArena-Chat \ - --seed 0 \ - --num-prompts $NUM_PROMPTS \ - --port $PROXY_PORT - -PIDS+=($!) - -############################################################################### -# Single request with local image -############################################################################### -echo "Running single request with local image (non-stream)..." 
-curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "'${MODEL}'", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": [ - {"type": "image_url", "image_url": {"url": "file://'"${GIT_ROOT}"'/tests/v1/ec_connector/integration/hato.jpg"}}, - {"type": "text", "text": "What is in this image?"} - ]} - ] - }' - - -# cleanup -echo "cleanup..." -cleanup \ No newline at end of file +# ############################################################################### +# # Benchmark +# ############################################################################### +# echo "Running benchmark (stream)..." +# vllm bench serve \ +# --model $MODEL \ +# --backend openai-chat \ +# --endpoint /v1/chat/completions \ +# --dataset-name hf \ +# --dataset-path lmarena-ai/VisionArena-Chat \ +# --seed 0 \ +# --num-prompts $NUM_PROMPTS \ +# --save-result \ +# --save-detailed \ +# --result-dir $LOG_PATH \ +# --result-filename ePD_nixl_$(date +"%Y%m%d_%H%M%S").json \ +# --port $PROXY_PORT + +# PIDS+=($!) + +# # ############################################################################### +# # # Single request with local image +# # ############################################################################### +# # echo "Running single request with local image (non-stream)..." +# # curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \ +# # -H "Content-Type: application/json" \ +# # -d '{ +# # "model": "'${MODEL}'", +# # "messages": [ +# # {"role": "system", "content": "You are a helpful assistant."}, +# # {"role": "user", "content": [ +# # {"type": "image_url", "image_url": {"url": "file://'"${GIT_ROOT}"'/tests/v1/ec_connector/integration/hato.jpg"}}, +# # {"type": "text", "text": "What is in this image?"} +# # ]} +# # ] +# # }' + + +# # cleanup +# echo "cleanup..." 
+# cleanup \ No newline at end of file diff --git a/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example_sharedstorage.sh b/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example_sharedstorage.sh new file mode 100644 index 000000000000..55d4d5408a5f --- /dev/null +++ b/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example_sharedstorage.sh @@ -0,0 +1,190 @@ +#!/bin/bash +set -euo pipefail + +declare -a PIDS=() + +############################################################################### +# Configuration -- override via env before running +############################################################################### +MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}" +LOG_PATH="${LOG_PATH:-./logs}" +mkdir -p $LOG_PATH + +ENCODE_PORT="${ENCODE_PORT:-19534}" +PREFILL_DECODE_PORT="${PREFILL_DECODE_PORT:-19535}" +PROXY_PORT="${PROXY_PORT:-10001}" + +GPU_E="${GPU_E:-2}" +GPU_PD="${GPU_PD:-1}" + +EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}" +TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout + +NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark + +############################################################################### +# Helpers +############################################################################### +# Find the git repository root directory +GIT_ROOT=$(git rev-parse --show-toplevel) + +START_TIME=$(date +"%Y%m%d_%H%M%S") +ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log +PD_LOG=$LOG_PATH/pd_${START_TIME}.log +PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log + +wait_for_server() { + local port=$1 + timeout "$TIMEOUT_SECONDS" bash -c " + until curl -s localhost:$port/v1/chat/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Cleanup function +cleanup() { + echo "Stopping everything…" + trap - INT TERM USR1 # prevent re-entrancy + + # Kill all tracked PIDs + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Killing process $pid" + kill "$pid" 2>/dev/null + fi + done + + # Wait a moment for graceful shutdown + sleep 2 + + # Force kill any remaining processes + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Force killing process $pid" + kill -9 "$pid" 2>/dev/null + fi + done + + # Kill the entire process group as backup + kill -- -$$ 2>/dev/null + + echo "All processes stopped." + exit 0 +} + +trap cleanup INT +trap cleanup USR1 +trap cleanup TERM + +# clear previous cache +echo "remove previous ec cache folder" +rm -rf $EC_SHARED_STORAGE_PATH + +echo "make ec cache folder" +mkdir -p $EC_SHARED_STORAGE_PATH + +############################################################################### +# Encoder worker +############################################################################### +CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ + --gpu-memory-utilization 0.01 \ + --port "$ENCODE_PORT" \ + --enforce-eager \ + --enable-request-id-headers \ + --no-enable-prefix-caching \ + --max-num-batched-tokens 114688 \ + --max-num-seqs 128 \ + --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ + --ec-transfer-config '{ + "ec_connector": "ECSharedStorageConnector", + "ec_role": "ec_producer", + "ec_connector_extra_config": { + "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" + } + }' \ + >"${ENC_LOG}" 2>&1 & + +PIDS+=($!) 
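# Optional sanity check (a minimal sketch, not part of the launch sequence
# above): once the encoder has served a request, the ECSharedStorageConnector
# configured above writes its cache under $EC_SHARED_STORAGE_PATH. The exact
# file layout is connector-defined, so this only lists the directory contents
# and total size to confirm the producer actually wrote something.
ls -la "$EC_SHARED_STORAGE_PATH" || true
du -sh "$EC_SHARED_STORAGE_PATH" 2>/dev/null || true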
+ +############################################################################### +# Prefill+Decode worker +############################################################################### +CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \ + --gpu-memory-utilization 0.7 \ + --port "$PREFILL_DECODE_PORT" \ + --enforce-eager \ + --enable-request-id-headers \ + --max-num-seqs 128 \ + --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ + --ec-transfer-config '{ + "ec_connector": "ECSharedStorageConnector", + "ec_role": "ec_consumer", + "ec_connector_extra_config": { + "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" + } + }' \ + >"${PD_LOG}" 2>&1 & + +PIDS+=($!) + +# Wait for workers +wait_for_server $ENCODE_PORT +wait_for_server $PREFILL_DECODE_PORT + +############################################################################### +# Proxy +############################################################################### +python disagg_epd_proxy.py \ + --host "0.0.0.0" \ + --port "$PROXY_PORT" \ + --encode-servers-urls "http://localhost:$ENCODE_PORT" \ + --prefill-servers-urls "disable" \ + --decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \ + >"${PROXY_LOG}" 2>&1 & + +PIDS+=($!) + +wait_for_server $PROXY_PORT +echo "All services are up!" + +############################################################################### +# Benchmark +############################################################################### +echo "Running benchmark (stream)..." +vllm bench serve \ + --model $MODEL \ + --backend openai-chat \ + --endpoint /v1/chat/completions \ + --dataset-name hf \ + --dataset-path lmarena-ai/VisionArena-Chat \ + --seed 0 \ + --num-prompts $NUM_PROMPTS \ + --save-result \ + --save-detailed \ + --result-dir $LOG_PATH \ + --result-filename ePD_nixl_shared_$(date +"%Y%m%d_%H%M%S").json \ + --port $PROXY_PORT + +PIDS+=($!) + +############################################################################### +# Single request with local image +############################################################################### +echo "Running single request with local image (non-stream)..." +curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "'${MODEL}'", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [ + {"type": "image_url", "image_url": {"url": "file://'"${GIT_ROOT}"'/tests/v1/ec_connector/integration/hato.jpg"}}, + {"type": "text", "text": "What is in this image?"} + ]} + ] + }' + + +# cleanup +echo "cleanup..." +cleanup \ No newline at end of file diff --git a/examples/online_serving/disaggregated_encoder/disagg_epd_proxy.py b/examples/online_serving/disaggregated_encoder/disagg_epd_proxy.py index b5f99683c2bf..7b2df5cb7037 100644 --- a/examples/online_serving/disaggregated_encoder/disagg_epd_proxy.py +++ b/examples/online_serving/disaggregated_encoder/disagg_epd_proxy.py @@ -77,18 +77,22 @@ async def fanout_encoder_primer( orig_request: dict, e_urls: list[str], req_id: str, -) -> None: +) -> dict[str, dict]: """ 1. Build one request *per MM item* with all text removed. 2. Send them concurrently to the encode cluster. 3. Raise if any of them fails. + 4. Collect and aggregate ec_transfer_params by mm_hash. 
+ + Returns: + dict mapping mm_hash to ec_transfer_params dict """ logger.info("[%s] Processing multimodal items...", req_id) mm_items = extract_mm_items(orig_request) if not mm_items: logger.info("[%s] No multimodal items, skipping encoder", req_id) - return # nothing to do + return {} # nothing to do logger.info("[%s] got %d multimodal items...", req_id, len(mm_items)) @@ -122,6 +126,9 @@ async def fanout_encoder_primer( results = await asyncio.gather(*tasks, return_exceptions=True) + # Aggregate ec_transfer_params by mm_hash + aggregated_ec_transfer_params: dict[str, dict] = {} + # Fail fast if any sub-request failed for idx, r in enumerate(results): if isinstance(r, Exception): @@ -152,10 +159,56 @@ async def fanout_encoder_primer( detail=f"Encoder request failed: {detail}", ) + # Extract ec_transfer_params from encoder response + try: + response_json = await r.json() + encoder_ec_transfer_params = response_json.get("ec_transfer_params") + if encoder_ec_transfer_params: + # encoder_ec_transfer_params is a dict keyed by mm_hash + # Format: {mm_hash: {do_remote_encode, num_encoder_tokens, remote_engine_id, ...}} + for mm_hash, mm_hash_params in encoder_ec_transfer_params.items(): + # Store params for this mm_hash + aggregated_ec_transfer_params[mm_hash] = { + "do_remote_encode": mm_hash_params.get("do_remote_encode", True), + "num_encoder_tokens": mm_hash_params.get("num_encoder_tokens"), + "mm_base_addr": mm_hash_params.get("mm_base_addr"), + "remote_engine_id": mm_hash_params.get("remote_engine_id"), + "remote_host": mm_hash_params.get("remote_host"), + "remote_port": mm_hash_params.get("remote_port"), + "tp_size": mm_hash_params.get("tp_size", 1), + } + logger.debug( + "[%s] Collected ec_transfer_params for mm_hash %s from encoder #%d", + req_id, + mm_hash, + idx, + ) + except Exception as e: + logger.warning( + "[%s] Failed to extract ec_transfer_params from encoder response #%d: %s", + req_id, + idx, + str(e), + ) + logger.info( - "[%s] All %d encoder requests completed successfully", req_id, len(mm_items) + "[%s] All %d encoder requests completed successfully, collected %d mm_hashes", + req_id, + len(mm_items), + len(aggregated_ec_transfer_params), ) + # Add aggregated ec_transfer_params to request data for prefill + req_data = orig_request + if aggregated_ec_transfer_params: + req_data["ec_transfer_params"] = aggregated_ec_transfer_params + logger.info( + "[%s] Added aggregated_ec_transfer_params for %d mm_hashes to prefill request", + req_id, + len(aggregated_ec_transfer_params), + ) + + return req_data async def maybe_prefill( req_data: dict, @@ -320,7 +373,7 @@ async def forward_non_stream( ) -> dict: try: # Step 1: Process through Encoder instance (if has MM input) - await fanout_encoder_primer(req_data, e_urls, req_id) + req_data = await fanout_encoder_primer(req_data, e_urls, req_id) # Step 2: Process through Prefill instance req_data = await maybe_prefill(req_data, p_url, req_id) @@ -329,6 +382,9 @@ async def forward_non_stream( logger.info("[%s] Forwarding to decode: %s", req_id, d_url) headers = {"x-request-id": req_id} + # logger.debug(f"hero: make non stream tokens become 3 only") + # req_data["max_tokens"] = 3 + # Non-streaming response async with decode_session.post( f"{d_url}/v1/chat/completions", json=req_data, headers=headers diff --git a/tests/v1/ec_connector/integration/run_epd_correctness_test.sh b/tests/v1/ec_connector/integration/run_epd_correctness_test.sh index 55dd39c0a957..eaebfa489223 100644 --- 
a/tests/v1/ec_connector/integration/run_epd_correctness_test.sh +++ b/tests/v1/ec_connector/integration/run_epd_correctness_test.sh @@ -33,7 +33,7 @@ GPU_E="${GPU_E:-0}" GPU_P="${GPU_P:-1}" GPU_D="${GPU_D:-2}" GPU_SINGLE="${GPU_SINGLE:-$GPU_P}" -GPU_PD="${GPU_PD:-$GPU_P}" +GPU_PD="${GPU_PD:-$GPU_D}" # Port ENCODE_PORT="${ENCODE_PORT:-19534}" @@ -138,7 +138,7 @@ run_epd_1e_1pd() { # Start encoder instance echo "Starting encoder instance on GPU $GPU_E, port $ENCODE_PORT" - CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ + CUDA_VISIBLE_DEVICES="$GPU_E" VLLM_NIXL_EC_SIDE_CHANNEL_PORT=5569 vllm serve "$MODEL" \ --port $ENCODE_PORT \ --enforce-eager \ --gpu-memory-utilization 0.01 \ @@ -148,18 +148,15 @@ run_epd_1e_1pd() { --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", - "ec_role": "ec_producer", - "ec_connector_extra_config": { - "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" - } + "ec_connector": "NixlECConnector", + "ec_role": "ec_producer" }' \ > $LOG_PATH/1e1pd_encoder.log 2>&1 & PIDS+=($!) # Start prefill+decode instance echo "Starting PD instance on GPU $GPU_PD, port $PREFILL_DECODE_PORT" - CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \ + CUDA_VISIBLE_DEVICES="$GPU_PD" VLLM_NIXL_EC_SIDE_CHANNEL_PORT=5579 vllm serve "$MODEL" \ --port $PREFILL_DECODE_PORT \ --enforce-eager \ --gpu-memory-utilization 0.7 \ @@ -167,11 +164,8 @@ run_epd_1e_1pd() { --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", - "ec_role": "ec_consumer", - "ec_connector_extra_config": { - "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" - } + "ec_connector": "NixlECConnector", + "ec_role": "ec_consumer" }' \ > $LOG_PATH/1e1pd_pd.log 2>&1 & PIDS+=($!) @@ -338,7 +332,7 @@ run_epd_1e_1p_1d() { # Start encoder instance echo "Starting encoder instance on GPU $GPU_E, port $ENCODE_PORT" - CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ + CUDA_VISIBLE_DEVICES="$GPU_E" VLLM_NIXL_EC_SIDE_CHANNEL_PORT=5589 vllm serve "$MODEL" \ --port $ENCODE_PORT \ --enforce-eager \ --gpu-memory-utilization 0.01 \ @@ -348,11 +342,8 @@ run_epd_1e_1p_1d() { --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", - "ec_role": "ec_producer", - "ec_connector_extra_config": { - "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" - } + "ec_connector": "NixlECConnector", + "ec_role": "ec_producer" }' \ > $LOG_PATH/1e1p1d_encoder.log 2>&1 & PIDS+=($!) 
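# A minimal, hypothetical helper (assumes Linux `ss` is available; it is not
# called anywhere in this script, run it manually when debugging): the
# NixlECConnector side channels configured above and below
# (VLLM_NIXL_EC_SIDE_CHANNEL_PORT=5569/5579/5589/5599) and the KV NixlConnector
# side channel (VLLM_NIXL_SIDE_CHANNEL_PORT=5559) each bind a TCP listener on
# this host, so the values must not collide. Listing the listeners on those
# ports helps spot clashes after editing the port numbers.
check_side_channel_ports() {
    ss -ltn 2>/dev/null | grep -E ':(5559|5569|5579|5589|5599)[[:space:]]' || true
}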
@@ -360,6 +351,7 @@ run_epd_1e_1p_1d() { # Start prefill instance echo "Starting prefill instance on GPU $GPU_P, port $PREFILL_PORT" CUDA_VISIBLE_DEVICES="$GPU_P" \ + VLLM_NIXL_EC_SIDE_CHANNEL_PORT=5599 \ VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \ vllm serve "$MODEL" \ --port $PREFILL_PORT \ @@ -369,11 +361,8 @@ run_epd_1e_1p_1d() { --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", - "ec_role": "ec_consumer", - "ec_connector_extra_config": { - "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" - } + "ec_connector": "NixlECConnector", + "ec_role": "ec_consumer" }' \ --kv-transfer-config '{ "kv_connector": "NixlConnector", diff --git a/vllm/distributed/ec_transfer/ec_connector/factory.py b/vllm/distributed/ec_transfer/ec_connector/factory.py index bfdf51d775bd..960b3242fe84 100644 --- a/vllm/distributed/ec_transfer/ec_connector/factory.py +++ b/vllm/distributed/ec_transfer/ec_connector/factory.py @@ -86,3 +86,10 @@ def get_connector_class( "vllm.distributed.ec_transfer.ec_connector.shared_storage_connector", "ECSharedStorageConnector", ) + +ECConnectorFactory.register_connector( + "NixlECConnector", + "vllm.distributed.ec_transfer.ec_connector.nixl_connector", + "NixlECConnector", +) + diff --git a/vllm/distributed/ec_transfer/ec_connector/nixl_connector.py b/vllm/distributed/ec_transfer/ec_connector/nixl_connector.py new file mode 100644 index 000000000000..441666634534 --- /dev/null +++ b/vllm/distributed/ec_transfer/ec_connector/nixl_connector.py @@ -0,0 +1,1188 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import logging +import queue +import threading +import time +from collections import defaultdict +from collections.abc import Iterator +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import msgspec +import torch +import zmq + +from vllm import envs +from vllm.config import VllmConfig +from vllm.distributed.ec_transfer.ec_connector.base import ( + ECConnectorBase, + ECConnectorMetadata, + ECConnectorRole, +) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils.network_utils import make_zmq_path, make_zmq_socket +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.outputs import ECConnectorOutput + +if TYPE_CHECKING: + from vllm.v1.core.encoder_cache_manager import MemorySegment + from vllm.v1.request import Request + +EngineId = str +MMHash = str + +GET_META_MSG = b"get_meta_msg" + +logger = init_logger(__name__) + + +# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used +try: + from nixl._api import nixl_agent as NixlWrapper + logger.info("NIXL is available") +except ImportError: + logger.warning("NIXL is not available") + NixlWrapper = None + +# Supported xPUs and types of encoder cache transfer buffer. 
+_NIXL_SUPPORTED_XPUS = { + "cuda": ("cuda",), + "tpu": ("cpu",), + "xpu": ("cpu",), +} + + +class NixlECAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + dict=True, +): + """Metadata exchanged during NIXL handshake for encoder cache.""" + engine_id: str + agent_metadata: bytes + enc_base_addr: int + enc_token_bytes: int + + +@dataclass +class ECReqMeta: + """Metadata for a single encoder cache transfer request.""" + mm_hash: str + num_encoder_tokens: int + mm_base_addr: int + remote_host: str + remote_port: int + remote_engine_id: str + tp_size: int + + +@dataclass +class NixlECConnectorMetadata(ECConnectorMetadata): + """Metadata passed from scheduler to worker for encoder cache transfers.""" + + def __init__(self): + self.reqs_to_recv: dict[MMHash, ECReqMeta] = {} + self.reqs_to_send: dict[MMHash, float] = {} + + def add_recv_req( + self, + mm_hash: MMHash, + num_encoder_tokens: int, + mm_base_addr: int, + remote_host: str, + remote_port: int, + remote_engine_id: str, + tp_size: int, + ): + """Add a request to receive encoder cache from remote.""" + self.reqs_to_recv[mm_hash] = ECReqMeta( + mm_hash=mm_hash, + num_encoder_tokens=num_encoder_tokens, + mm_base_addr=mm_base_addr, + remote_host=remote_host, + remote_port=remote_port, + remote_engine_id=remote_engine_id, + tp_size=tp_size, + ) + + +class NixlECConnector(ECConnectorBase): + """NIXL-based encoder cache connector for disaggregated encoder setups.""" + + def __init__(self, vllm_config: VllmConfig, role: ECConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + assert vllm_config.ec_transfer_config is not None + assert vllm_config.ec_transfer_config.engine_id is not None + self.engine_id: EngineId = vllm_config.ec_transfer_config.engine_id + + if role == ECConnectorRole.SCHEDULER: + self.connector_scheduler: Optional[NixlECConnectorScheduler] = ( + NixlECConnectorScheduler(vllm_config, self.engine_id) + ) + self.connector_worker: Optional[NixlECConnectorWorker] = None + elif role == ECConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = NixlECConnectorWorker( + vllm_config, self.engine_id + ) + + ############################################################ + # Worker Side Methods + ############################################################ + + def register_encoder_cache( + self, + ec_cache: torch.Tensor, + ): + """Register encoder cache tensors with NIXL.""" + assert self.connector_worker is not None + # For NIXL, we register the main encoder cache tensor + # Individual mm_hash caches are handled via recv tensors + if hasattr(self.connector_worker, "encoder_cache") and \ + self.connector_worker.encoder_cache is not None: + # Already registered + return + # The encoder_cache will be registered when it's first set + # via register_encoder_cache method + self.connector_worker.register_encoder_cache(ec_cache) + + def start_load_caches( + self, encoder_cache, **kwargs + ) -> None: + """Start loading encoder caches from remote via NIXL.""" + assert self.connector_worker is not None + metadata: NixlECConnectorMetadata = self._get_connector_metadata() + + self.connector_worker.start_load_caches(encoder_cache, metadata) + + def save_caches( + self, encoder_cache: dict[str, torch.Tensor], mm_hash: str, **kwargs + ) -> None: + """Save encoder cache to remote via NIXL (no-op for NIXL, handled by request_finished).""" + # NIXL handles saving via request_finished callback + # This method is called but we don't need to do anything here + + + # pass # hero + + # add 
mm_base_addr to metadata by assign the data pointer add of the cache + # TODO: find a better way... this is supposed to be worker side method + + # assert self.connector_scheduler is not None + # TODO: reg cache per mm + # TODO: it doesnt work now coz self.connector_scheduler in worker role is None. they are different connector objects + assert self.connector_worker is not None + logger.debug(f"hero: save_caches!!!!") + if mm_hash in encoder_cache: + base_addr = encoder_cache[mm_hash].data_ptr() + logger.debug(f"hero: base_addr: {base_addr}") + + if self.engine_id not in self.connector_worker._ENCODER_MM_BASE_ADDRS: + self.connector_worker._ENCODER_MM_BASE_ADDRS[self.engine_id] = {} + logger.debug(f"hero: self.connector_worker._ENCODER_MM_BASE_ADDRS: {self.connector_worker._ENCODER_MM_BASE_ADDRS}") + + self.connector_worker._ENCODER_MM_BASE_ADDRS[self.engine_id][mm_hash] = base_addr + logger.debug(f"hero: self.connector_worker._ENCODER_MM_BASE_ADDRS: {self.connector_worker._ENCODER_MM_BASE_ADDRS}") + + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[set[str] | None, set[str] | None]: + """Get finished receiving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished(finished_req_ids) + + def get_mm_hash_addrs(self): + """Get dict of addresses of encoder cache tensor by mm hash""" + assert self.connector_worker is not None + return self.connector_worker._ENCODER_MM_BASE_ADDRS.copy() + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def has_caches(self, request: "Request") -> list[bool]: + """Check if encoder cache exists remotely for each mm_data.""" + assert self.connector_scheduler is not None + return self.connector_scheduler.has_caches(request) + + def update_state_after_alloc(self, request: "Request", index: int) -> None: + """Update state after encoder cache allocation.""" + assert self.connector_scheduler is not None + self.connector_scheduler.update_state_after_alloc(request, index) + + def update_mm_hash_addrs_from_output(self, ec_connector_output: ECConnectorOutput) -> None: + assert self.connector_scheduler is not None + assert ec_connector_output is not None + logger.debug(f"hero: update_mm_hash_addrs_from_output in connector!!!!") + self.connector_scheduler.update_mm_hash_addrs_from_output(ec_connector_output) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> ECConnectorMetadata: + """Build connector metadata for this step.""" + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, request: "Request" + ) -> tuple[bool, dict[str, Any] | None]: + """Called when request finishes, returns transfer params if needed.""" + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request) + + +class NixlECConnectorScheduler: + """Scheduler-side implementation of NIXL EC connector.""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self.vllm_config = vllm_config + self.engine_id: EngineId = engine_id + self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + self.side_channel_port = ( + envs.VLLM_NIXL_EC_SIDE_CHANNEL_PORT + + vllm_config.parallel_config.data_parallel_rank + * vllm_config.parallel_config.tensor_parallel_size + ) + logger.info("Initializing NIXL EC Scheduler %s", engine_id) + + # Track mm_hashes that need to 
be loaded from remote + # mm_hash -> (request, num_encoder_tokens) + self._mm_hashes_need_recv: dict[ + MMHash, tuple["Request", int] + ] = {} + # Track mm_hashes that need to be sent (for producer role) + self._mm_hashes_need_send: dict[MMHash, float] = {} + + self.is_producer = vllm_config.ec_transfer_config.is_ec_producer + + # TODO: find a more elegant way to store & manage mm_base_addr + self._ENCODER_MM_BASE_ADDRS: dict[EngineId, dict[MMHash, int]] = {} + + def has_caches(self, request: "Request") -> list[bool]: + """Check if encoder cache exists remotely for each mm_data.""" + result = [] + + # Hero brute-force, return all true if is consumer + + # ec_transfer_params = getattr(request, "ec_transfer_params", None) + # remote_mm_segments = None + # if ec_transfer_params: + # remote_mm_segments = ec_transfer_params.get("remote_mm_segments") + + ec_transfer_params = getattr(request, "ec_transfer_params", None) + + for feature in request.mm_features: + mm_hash = feature.identifier + # Cache exists if remote_mm_segments contains this mm_hash + + # # Hero brute-force, return all true if is consumer + # has_cache = not self.is_producer + # logger.debug(f"Hero: has_cache for mm_hash {mm_hash}: {has_cache}") + # result.append(has_cache) + + if self.is_producer: + has_cache = False + else: + mm_hash_params = ( + ec_transfer_params.get(mm_hash) if ec_transfer_params else None + ) + has_cache = bool( + mm_hash_params + and mm_hash_params.get("num_encoder_tokens", 0) > 0 + and all( + p in mm_hash_params + for p in ("remote_engine_id", "remote_host", "remote_port") + ) + ) + result.append(has_cache) + + logger.debug(f"has_caches results: {result}") + return result + + def update_state_after_alloc(self, request: "Request", index: int) -> None: + """Update state after encoder cache allocation.""" + ec_transfer_params = getattr(request, "ec_transfer_params", None) + if not ec_transfer_params: + return + + mm_hash = request.mm_features[index].identifier + + # ec_transfer_params is now a dict keyed by mm_hash: {mm_hash: {...}} + # Extract params for this specific mm_hash + mm_hash_params = ec_transfer_params.get(mm_hash) + if not mm_hash_params: + logger.debug( + "No ec_transfer_params found for mm_hash %s in request %s", + mm_hash, + request.request_id, + ) + return + + if mm_hash_params.get("do_remote_encode"): + if all( + p in mm_hash_params + for p in ("remote_engine_id", "remote_host", "remote_port") + ): + # Get num_encoder_tokens from the request + num_encoder_tokens = request.get_num_encoder_tokens(index) + self._mm_hashes_need_recv[mm_hash] = ( + request, + num_encoder_tokens, + ) + logger.debug( + "Added mm_hash %s to recv queue with num_encoder_tokens: %d", + mm_hash, + num_encoder_tokens, + ) + else: + logger.warning( + "Got invalid ECTransferParams for mm_hash %s: %s. 
This " + "request will not utilize EC transfer", + mm_hash, + mm_hash_params, + ) + + # Only trigger 1 EC transfer per mm_hash + mm_hash_params["do_remote_encode"] = False + + def update_mm_hash_addrs_from_output(self, ec_connector_output: ECConnectorOutput) -> None: + logger.debug(f"update_mm_hash_addrs_from_output ec_connector_output.mm_hash_addrs: {ec_connector_output.mm_hash_addrs}") + mm_hash_addrs = ec_connector_output.mm_hash_addrs + logger.debug(f"hero: mm_hash_addrs: {mm_hash_addrs}") + self._ENCODER_MM_BASE_ADDRS.update(mm_hash_addrs) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> ECConnectorMetadata: + """Build connector metadata for this step.""" + meta = NixlECConnectorMetadata() + + # Convert mm_hashes to metadata + for mm_hash, (request, num_encoder_tokens) in self._mm_hashes_need_recv.items(): + ec_transfer_params = getattr(request, "ec_transfer_params", None) + if ec_transfer_params: + # Extract params for this specific mm_hash + mm_hash_params = ec_transfer_params.get(mm_hash) + logger.debug(f"hero: mm_hash_params for {mm_hash}: {mm_hash_params}") + if mm_hash_params: + meta.add_recv_req( + mm_hash=mm_hash, + num_encoder_tokens=num_encoder_tokens, + mm_base_addr=mm_hash_params["mm_base_addr"], + remote_host=mm_hash_params["remote_host"], + remote_port=mm_hash_params["remote_port"], + remote_engine_id=mm_hash_params["remote_engine_id"], + tp_size=mm_hash_params.get("tp_size", 1), + ) + else: + logger.warning( + "No ec_transfer_params found for mm_hash %s in request %s", + mm_hash, + request.request_id, + ) + + meta.reqs_to_send = self._mm_hashes_need_send + + # Clear the lists once workers start the transfers + self._mm_hashes_need_recv.clear() + self._mm_hashes_need_send = {} + + return meta + + def request_finished( + self, request: "Request" + ) -> tuple[bool, dict[str, Any] | None]: + """Called when request finishes, returns transfer params if needed. 
+ + For encoder instances (producers), returns ec_transfer_params keyed by mm_hash + in the format: {mm_hash: {do_remote_encode, remote_mm_segments, ...}} + """ + if not self.is_producer: + # Consumer doesn't return params + return False, None + + # Build params for all mm_hashes in this request + result_params: dict[str, dict[str, Any]] = {} + for idx, feature in enumerate(request.mm_features): + mm_hash = feature.identifier + num_encoder_tokens = request.get_num_encoder_tokens(idx) + + # Mark mm_hash to be sent asynchronously + self._mm_hashes_need_send[mm_hash] = ( + time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + ) + + mm_base_addr = self._ENCODER_MM_BASE_ADDRS.get(self.engine_id, {}).get(mm_hash) + logger.debug(f"hero: mm_base_addr is {mm_base_addr} for mm_hash {mm_hash}") + logger.debug(f"hero: self._ENCODER_MM_BASE_ADDRS {self._ENCODER_MM_BASE_ADDRS}") + + # Return params keyed by mm_hash for proxy aggregation + # Format: {mm_hash: {do_remote_encode, num_encoder_tokens, remote_engine_id, ...}} + result_params[mm_hash] = { + "do_remote_encode": True, + "num_encoder_tokens": num_encoder_tokens, + "mm_base_addr": mm_base_addr, + "remote_engine_id": self.engine_id, + "remote_host": self.side_channel_host, + "remote_port": self.side_channel_port, + "tp_size": self.vllm_config.parallel_config.tensor_parallel_size, + } + + return len(result_params) > 0, result_params if result_params else None + + +class NixlECConnectorWorker: + """Worker-side implementation of NIXL EC connector.""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + if NixlWrapper is None: + logger.error("NIXL is not available") + raise RuntimeError("NIXL is not available") + logger.info("Initializing NIXL EC Worker %s", engine_id) + + self.vllm_config = vllm_config + self.engine_id: EngineId = engine_id + self.tp_rank = get_tensor_model_parallel_rank() + self.world_size = get_tensor_model_parallel_world_size() + + # NIXL agent + self.nixl_wrapper = NixlWrapper(engine_id, None) + self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) + + # NIXL handshake port + self.side_channel_port: int = ( + envs.VLLM_NIXL_EC_SIDE_CHANNEL_PORT + + vllm_config.parallel_config.data_parallel_rank + * vllm_config.parallel_config.tensor_parallel_size + ) + + # Device type and memory registration + self.device_type = current_platform.device_type + if self.device_type not in _NIXL_SUPPORTED_XPUS: + raise RuntimeError(f"{self.device_type} is not supported.") + if self.device_type == "cuda": + self.nixl_memory_type = "VRAM" + elif self.device_type in ("tpu", "xpu"): + self.nixl_memory_type = "DRAM" + else: + raise RuntimeError( + f"{self.device_type} is not supported for encoder cache transfer." 
+ ) + self.encoder_cache_dtype: torch.dtype = vllm_config.model_config.dtype + + # Encoder cache registration + self.encoder_cache: Optional[torch.Tensor] = None + self.enc_base_addr = 0 + self.enc_token_bytes = 0 + self._registered_descs: list[Any] = [] + # TODO: find a more elegant way to store & manage mm_base_addr + self._ENCODER_MM_BASE_ADDRS: dict[EngineId, dict[MMHash, int]] = {} + + # Remote encoder cache addresses + self._remote_enc_base_addr: dict[EngineId, tuple[int, int]] = {} + self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} + + # Transfer tracking + self._recving_metadata: dict[MMHash, ECReqMeta] = {} + self._recving_transfers: dict[MMHash, list[tuple[int, float]]] = ( + defaultdict(list) + ) + self._reqs_to_send: dict[MMHash, float] = {} + self.mm_consumer_counts_by_req: dict[MMHash, int] = defaultdict(int) + + # Registered receive tensors and segments + self._registered_mm_descs: dict[ + str, tuple[int, Any, torch.Tensor, list["MemorySegment"]] + ] = {} + self._xfer_side_mm_handle: dict[ + tuple[str, int], tuple[str, int, int] + ] = {} + + # Encoder cache dict reference on consumer side + # Passed form the model runner + self._encoder_cache_dict: dict[str, torch.Tensor] | None = None + + # Background handshake handling + self._nixl_handshake_listener_t: threading.Thread | None = None + self._stop_event = threading.Event() + + self._handshake_initiation_executor = ThreadPoolExecutor( + max_workers=1, + thread_name_prefix="vllm-nixl-ec-handshake-initiator", + ) + self._ready_requests = queue.Queue[tuple[MMHash, ECReqMeta]]() + self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {} + self._handshake_lock = threading.RLock() + + self.is_producer = vllm_config.ec_transfer_config.is_ec_producer + + def __del__(self): + """Cleanup background threads on destruction.""" + self._handshake_initiation_executor.shutdown(wait=False) + if self._nixl_handshake_listener_t: + self._nixl_handshake_listener_t.join(timeout=0) + + def register_encoder_cache(self, encoder_cache: torch.Tensor): + """Register the main encoder cache tensor with NIXL.""" + self.encoder_cache = encoder_cache + self.enc_base_addr = encoder_cache.data_ptr() + self.enc_token_bytes = ( + encoder_cache[0].numel() * encoder_cache.element_size() + ) + enc_size_bytes = encoder_cache.numel() * encoder_cache.element_size() + + caches_data = [(self.enc_base_addr, enc_size_bytes, 0, "")] + descs = self.nixl_wrapper.get_reg_descs( + caches_data, self.nixl_memory_type + ) + + logger.debug("Registering encoder cache descs: %s", caches_data) + self.nixl_wrapper.register_memory(descs) + logger.debug("Done registering encoder cache descs") + self._registered_descs.append(descs) + + # Start handshake listener for encoder-only instances + if self.is_producer: + metadata = NixlECAgentMetadata( + engine_id=self.engine_id, + agent_metadata=self.nixl_wrapper.get_agent_metadata(), + enc_base_addr=self.enc_base_addr, + enc_token_bytes=self.enc_token_bytes, + ) + logger.debug(f"hero: metadata at encoder: {metadata}") + ready_event = threading.Event() + self._nixl_handshake_listener_t = threading.Thread( + target=self._nixl_handshake_listener, + args=( + metadata, + ready_event, + self._stop_event, + self.side_channel_port, + ), + daemon=True, + name="nixl_ec_handshake_listener", + ) + self._nixl_handshake_listener_t.start() + ready_event.wait() + + @staticmethod + def _nixl_handshake_listener( + metadata: NixlECAgentMetadata, + ready_event: threading.Event, + stop_event: threading.Event, + port: int, + ): + 
"""Background thread for handling NIXL handshake requests.""" + encoder = msgspec.msgpack.Encoder() + encoded_data = encoder.encode(metadata) + size_in_bytes = len(encoded_data) + logger.debug("Size of encoded NixlECAgentMetadata: %s bytes", size_in_bytes) + + # host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + # path = make_zmq_path("tcp", host, base_port + tp_rank) + # logger.debug("Starting EC listener on path: %s", path) + # with zmq_ctx(zmq.ROUTER, path) as sock: + # ready_event.set() + # while True: + # identity, _, msg = sock.recv_multipart() + # if msg != GET_META_MSG: + # logger.warning( + # "EC connection listener got unexpected message %s", msg + # ) + # sock.send_multipart((identity, b"", encoded_data)) + + # hero (from kv nixl) + # Listen for new requests for metadata. + host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + path = make_zmq_path("tcp", host, port) + logger.debug("Starting listening on path: %s", path) + with zmq_ctx(zmq.ROUTER, path) as sock: + sock.setsockopt(zmq.RCVTIMEO, 1000) + ready_event.set() + while True: + try: + identity, _, msg = sock.recv_multipart() + except zmq.Again: + if stop_event.is_set(): + break + continue + # Decode the message which contains (GET_META_MSG, rank) + msg, target_tp_rank = msgspec.msgpack.decode(msg) + logger.debug( + "Received message for tp rank %s; msg: %s", + target_tp_rank, + msg, + ) + if msg != GET_META_MSG: + logger.warning("Connection listener got unexpected message %s", msg) + logger.debug(f"hero: encoded_data: {encoded_data}") + sock.send_multipart((identity, b"", encoded_data)) + + def _nixl_handshake( + self, + host: str, + port: int, + remote_tp_size: int, + expected_engine_id: str, + ) -> dict[int, str]: + """Do a NIXL handshake with a remote encoder instance.""" + start_time = time.perf_counter() + + # tp_ratio = self._tp_size[self.engine_id] // remote_tp_size + # p_remote_rank = self.tp_rank // tp_ratio + # path = make_zmq_path("tcp", host, port + p_remote_rank) + # logger.debug( + # "Querying EC metadata on path: %s at remote rank %s", + # path, + # p_remote_rank, + # ) + + tp_ratio = self._tp_size[self.engine_id] // remote_tp_size + p_remote_rank = self.tp_rank // tp_ratio + + path = make_zmq_path("tcp", host, port) + logger.debug( + "Querying EC metadata on path: %s", + path, + ) + + with zmq_ctx(zmq.REQ, path) as sock: + msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank)) + sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds + sock.send(msg) + logger.debug("hero: aaaaaaaaaaaaaaaaaaaaaa") + metadata_bytes = sock.recv() + logger.debug("hero: bbbbbbbbbbbbbbbbbbbbbbbb") + decoder = msgspec.msgpack.Decoder(NixlECAgentMetadata) + logger.debug("hero: ccccccccccccccccccccccc") + metadata = decoder.decode(metadata_bytes) + logger.debug("hero: ddddddddddddddddddddddddddd") + logger.debug(f"hero: metadata: {metadata}") + got_metadata_time = time.perf_counter() + logger.debug( + "NIXL EC handshake: get metadata took: %s", + got_metadata_time - start_time, + ) + + if metadata.engine_id != expected_engine_id: + raise RuntimeError( + f"Remote NIXL EC agent engine ID mismatch. " + f"Expected {expected_engine_id}, received {metadata.engine_id}." 
+ ) + + remote_agent_name = self.add_remote_agent( + metadata, p_remote_rank, remote_tp_size + ) + setup_agent_time = time.perf_counter() + logger.debug( + "NIXL EC handshake: add agent took: %s", + setup_agent_time - got_metadata_time, + ) + + return {p_remote_rank: remote_agent_name} + + def add_remote_agent( + self, + nixl_agent_meta: NixlECAgentMetadata, + remote_tp_rank: int = 0, + remote_tp_size: int = 1, + ) -> str: + """Add remote NIXL agent and prepare descriptors for reading encoder cache.""" + engine_id = nixl_agent_meta.engine_id + if remote_tp_rank in self._remote_agents.get(engine_id, {}): + return self._remote_agents[engine_id][remote_tp_rank] + + if engine_id not in self._tp_size: + self._tp_size[engine_id] = remote_tp_size + else: + assert self._tp_size[engine_id] == remote_tp_size + + remote_agent_name = self.nixl_wrapper.add_remote_agent( + nixl_agent_meta.agent_metadata + ) + + self._remote_enc_base_addr[engine_id] = ( + nixl_agent_meta.enc_base_addr, + nixl_agent_meta.enc_token_bytes, + ) + logger.debug(f"hero: add_remote_agent end; remote_agent_name: {remote_agent_name}") + logger.debug(f"hero: added {nixl_agent_meta.enc_base_addr}, {nixl_agent_meta.enc_token_bytes}") + return remote_agent_name + + def _background_nixl_handshake( + self, mm_hash: str, remote_engine_id: EngineId, meta: ECReqMeta + ): + """Do NIXL handshake in background and add to _ready_requests when done.""" + fut = self._handshake_futures.get(remote_engine_id) + if fut is None: + fut = self._handshake_initiation_executor.submit( + self._nixl_handshake, + meta.remote_host, + meta.remote_port, + meta.tp_size, + remote_engine_id, + ) + logger.debug(f"hero: after submit") + self._handshake_futures[remote_engine_id] = fut + logger.debug(f"hero: after fut") + + def done_callback(f: Future[dict[int, str]], eid=remote_engine_id): + with self._handshake_lock: + logger.debug(f"hero: del") + del self._handshake_futures[eid] + try: + self._remote_agents[eid] = f.result() + except Exception: + logger.exception("EC handshake with %s failed", eid) + + logger.debug(f"hero: add_done_callback") + fut.add_done_callback(done_callback) + + # def request_ready(_f: Future[Any], entry=(mm_hash, meta)): + # self._ready_requests.put(entry) + # hero + + # check handshake success before proceeding with request + def request_ready(f: Future[Any], entry=(mm_hash, meta)): + try: + # check if handshake succeeded + f.result() + self._ready_requests.put(entry) + logger.debug(f"hero: request is ready! 
entry: {entry}; f.result(): {f.result()}") + logger.debug(f"hero self._ready_requests.empty() after request ready: {self._ready_requests.empty()}") + except Exception: + # handshake failed + logger.exception( + "Handshake failed for mm_hash %s", mm_hash + ) + + fut.add_done_callback(request_ready) + + def register_encoder_recv_tensor( + self, + mm_hash: str, + recv_tensor: torch.Tensor, + local_segments: list["MemorySegment"] | None = None, + ): + """Register a receive tensor for encoder cache transfer.""" + base_addr = recv_tensor.data_ptr() + size_bytes = recv_tensor.numel() * recv_tensor.element_size() + + self.device_id = max(recv_tensor.get_device(), 0) + caches_data = [(base_addr, size_bytes, self.device_id, "")] + + descs = self.nixl_wrapper.get_reg_descs( + caches_data, self.nixl_memory_type + ) + logger.debug("Registering descs: %s", caches_data) + self.nixl_wrapper.register_memory(descs) + logger.debug("Done registering descs") + self._registered_mm_descs[mm_hash] = ( + base_addr, + descs, + recv_tensor, + local_segments, + ) + logger.debug(f"hero: self._registered_mm_descs[mm_hash]s for mm_hash {mm_hash} with base_addr {base_addr}") + + + def start_load_caches( + self, + encoder_cache, + metadata: NixlECConnectorMetadata, + ): + """Start loading encoder caches from remote via NIXL.""" + + # Get the metadata + assert isinstance(metadata, NixlECConnectorMetadata) + assert encoder_cache is not None + if metadata is None: + logger.warning( + ( + "In connector.start_load_caches, ", + "but the connector metadata is None", + ) + ) + return + logger.debug(f"hero: start_load_caches: {metadata.reqs_to_recv.items()}") + + # Reference the encoder_cache + self._encoder_cache_dict = encoder_cache + + # hero + # # First, register all receive tensors from encoder_cache + # logger.debug(f"encoder_cache: {encoder_cache}") + # for mm_hash, meta in metadata.reqs_to_recv.items(): + # if mm_hash in encoder_cache: + # recv_tensor = encoder_cache[mm_hash] + # if mm_hash not in self._registered_descs: + # logger.debug( + # "Registering receive tensor for mm_hash %s, shape: %s", + # mm_hash, + # recv_tensor.shape, + # ) + # self.register_encoder_recv_tensor(mm_hash, recv_tensor) + # else: + # logger.warning( + # "mm_hash %s not found in encoder_cache, cannot register receive tensor!", + # mm_hash, + # ) + + logger.debug(f"hero: metadata.reqs_to_recv.items(): {metadata.reqs_to_recv.items()}") + for mm_hash, meta in metadata.reqs_to_recv.items(): + logger.debug(f"hero: {(mm_hash, meta)}") + remote_engine_id = meta.remote_engine_id + self._recving_metadata[mm_hash] = meta + logger.debug( + "start_load_caches for mm_hash %s from remote engine %s", + mm_hash, + remote_engine_id, + ) + + if remote_engine_id not in self._remote_agents: + # Initiate handshake with remote engine + with self._handshake_lock: + if remote_engine_id not in self._remote_agents: + self._background_nixl_handshake( + mm_hash, remote_engine_id, meta + ) + logger.debug(f"hero: _background_nixl_handshake after") + # time.sleep(10) + # logger.debug(f"hero: wait 10 for request ready before read mm and during handshake") + continue + + logger.debug(f"hero: before read mm") + # time.sleep(2) + # logger.debug(f"hero: wait 2 for request ready before read mm") + # Handshake completed, start async read transfer + self._read_mm_segments(mm_hash, meta) + + if metadata.reqs_to_recv: # if not empty + for mm_hash, meta in metadata.reqs_to_recv.items(): + remote_engine_id = meta.remote_engine_id + if remote_engine_id not in self._remote_agents: + 
logger.debug(f"hero: {remote_engine_id} not in {self._remote_agents}") + logger.debug(f"hero: start wait 2 for request ready before read mm") + time.sleep(2) + logger.debug(f"hero: waited 2 for request ready before read mm") + + # time.sleep(2) + # logger.debug(f"hero: wait 2 for request ready before read mm") + # Start transfers for requests whose handshakes have finished + logger.debug(f"hero: self._ready_requests.empty(): {self._ready_requests.empty()}") + # logger.debug(f"hero: self._ready_requests.get_nowait(): {self._ready_requests.get_nowait()}") + # logger.debug(f"hero: self._ready_requests.empty(): {self._ready_requests.empty()}") + + while not self._ready_requests.empty(): + logger.debug(f"while not self._ready_requests.empty():") + self._read_mm_segments(*self._ready_requests.get_nowait()) + + # Add to requests waiting to be read + self._reqs_to_send.update(metadata.reqs_to_send) + + def _read_mm_segments(self, mm_hash: str, meta: ECReqMeta): + """Read encoder cache from remote via NIXL. + + Transfers the entire encoder cache tensor for the given mm_hash. + """ + logger.debug(f"hero: _read_mm_segments") + logger.debug(f"hero: self._remote_agents: {self._remote_agents}") + remote_engine_id = meta.remote_engine_id + num_encoder_tokens = meta.num_encoder_tokens + + if num_encoder_tokens == 0: + # No tokens to read, just send notification + tp_ratio = ( + self._tp_size[self.engine_id] // self._tp_size[remote_engine_id] + ) + notif_id = f"{mm_hash}:{tp_ratio}$1".encode() + agent_name = self._remote_agents[remote_engine_id][0] + self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) + return + + if mm_hash not in self._registered_mm_descs: + if remote_engine_id not in self._remote_enc_base_addr: + logger.error( + "Remote encoder base addr for engine %s not found when " + "starting transfer for mm_hash %s", + remote_engine_id, + mm_hash + ) + return + # logger.warning( + # "mm_hash %s not registered for receive, skipping transfer", + # mm_hash, + # ) + # return + # hero + base_addr, token_bytes = self._remote_enc_base_addr[remote_engine_id] + + # Derive hidden size from bytes-per-token and dtype element size. + elem_size = torch.tensor([], dtype=self.encoder_cache_dtype).element_size() + assert token_bytes % elem_size == 0, ( + f"enc_token_bytes {token_bytes} not divisible by element size " + f"{elem_size} for dtype {self.encoder_cache_dtype}" + ) + hidden_size = token_bytes // elem_size + + # Allocate local receive tensor and expose it to the encoder_cache dict. 
+ recv_tensor = torch.empty( + (num_encoder_tokens, hidden_size), + device=self.device_type, + dtype=self.encoder_cache_dtype, + ) + + assert self._encoder_cache_dict is not None + logger.debug(f"hero: self._encoder_cache_dict: {self._encoder_cache_dict}") + self._encoder_cache_dict[mm_hash] = recv_tensor + + logger.debug(f"hero: self._encoder_cache_dict after recv_tensor: {self._encoder_cache_dict}") + + logger.debug(f"hero: size: {recv_tensor.size(), self._encoder_cache_dict[mm_hash].size()} / recv_tensor for {mm_hash}: {self._encoder_cache_dict[mm_hash]}") + + logger.debug( + "Allocating receive tensor for mm_hash %s with shape %s " + "(num_tokens=%d, hidden_size=%d)", + mm_hash, + recv_tensor.shape, + num_encoder_tokens, + hidden_size, + ) + + self.register_encoder_recv_tensor(mm_hash, recv_tensor) + + base_addr, token_bytes = self._remote_enc_base_addr[remote_engine_id] + local_base_addr, _, recv_tensor, _ = self._registered_mm_descs[mm_hash] + + tp_ratio = ( + self._tp_size[self.engine_id] // self._tp_size[remote_engine_id] + ) + notif_id = f"{mm_hash}:{tp_ratio}$1".encode() + agent_name = self._remote_agents[remote_engine_id][0] + + # Transfer the whole tensor: offset 0, length num_encoder_tokens + seg_bytes = num_encoder_tokens * token_bytes + remote_addr = meta.mm_base_addr # base_addr # Start from base address (offset 0) + local_addr = local_base_addr # Start from local base address (offset 0) + + remote_segments = [(remote_addr, seg_bytes, 0)] + local_seg_addrs = [(local_addr, seg_bytes, 0)] + idx = [0] # Single segment + + logger.debug(f"hero: remote_addr: {remote_addr}; local_addr: {local_addr}; num_encoder_tokens: {num_encoder_tokens}; mm_hash: {mm_hash}") + + src_descs = self.nixl_wrapper.get_xfer_descs( + local_seg_addrs, self.nixl_memory_type + ) + src_xfer_handle = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", src_descs + ) + + dst_descs = self.nixl_wrapper.get_xfer_descs( + remote_segments, self.nixl_memory_type + ) + dst_xfer_handle = self.nixl_wrapper.prep_xfer_dlist( + agent_name, dst_descs + ) + + handle = self.nixl_wrapper.make_prepped_xfer( + "READ", + src_xfer_handle, + idx, + dst_xfer_handle, + idx, + notif_msg=notif_id, + ) + + self.nixl_wrapper.transfer(handle) + + self._xfer_side_mm_handle[(mm_hash, handle)] = ( + mm_hash, + src_xfer_handle, + dst_xfer_handle, + ) + + self._recving_transfers[mm_hash].append((handle, time.perf_counter())) + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[set[str] | None, set[str] | None]: + """Get finished receiving and sending requests.""" + done_sending = self._get_new_notifs() + done_recving = self._pop_done_transfers(self._recving_transfers) + + # Copy received data to encoder cache if needed + for mm_hash in done_recving: + logger.debug(f"mm_hash {mm_hash} done recving") + if mm_hash in self._recving_metadata: + meta = self._recving_metadata.pop(mm_hash) + self._copy_recv_to_encoder_cache(mm_hash) + + # Handle timeout + now = time.perf_counter() + expired_mm_hashes = [] + for mm_hash, expires in list(self._reqs_to_send.items()): + if now >= expires: + count = self.mm_consumer_counts_by_req.pop(mm_hash, 0) + logger.warning( + "Releasing expired EC cache for mm_hash %s which was " + "retrieved by %d consumer(s) within %d seconds.", + mm_hash, + count, + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT, + ) + expired_mm_hashes.append(mm_hash) + done_sending.add(mm_hash) + + for mm_hash in expired_mm_hashes: + del self._reqs_to_send[mm_hash] + + if len(done_sending) > 0 or len(done_recving) > 0: + 
logger.debug( + "Rank %s, get_finished: %s requests done sending " + "and %s requests done recving", + self.tp_rank, + len(done_sending), + len(done_recving), + ) + + return done_sending if done_sending else None, ( + done_recving if done_recving else None + ) + + def _copy_recv_to_encoder_cache(self, mm_hash: str): + """Mark received encoder cache tensor as ready.""" + logger.debug(f"hero: _copy_recv_to_encoder_cache for mm_hash {mm_hash}") + if mm_hash not in self._registered_mm_descs: + # time.sleep(3) # hero + if mm_hash not in self._registered_mm_descs: + # logger.debug(f"hero: after 3s mm_hash {mm_hash} still not in {self._registered_mm_descs}") + return + + _, _, recv_tensor, _ = self._registered_mm_descs[mm_hash] + + # # Copy the whole tensor to encoder_cache[mm_hash] + # # This matches the shared_storage_connector pattern + # if self.encoder_cache is not None: + # self.encoder_cache[mm_hash] = recv_tensor.clone() + # logger.debug( + # "Copied received encoder cache for mm_hash %s, shape: %s", + # mm_hash, + # recv_tensor.shape, + # ) + + # hero: + if self._encoder_cache_dict is not None: + self._encoder_cache_dict[mm_hash] = recv_tensor.clone() + logger.debug( + "Copied received encoder cache for mm_hash %s, shape: %s", + mm_hash, + recv_tensor.shape, + ) + logger.debug(f"hero: mm_hash {mm_hash} cloned tensor: {recv_tensor}") + + logger.debug( + "Encoder cache transfer completed for mm_hash %s, shape: %s", + mm_hash, + recv_tensor.shape, + ) + + def _get_new_notifs(self) -> set[str]: + """Get mm_hashes which got a remote transfer notification.""" + notified_mm_hashes: set[str] = set() + for notifs in self.nixl_wrapper.get_new_notifs().values(): + for notif in notifs: + notif_decode = notif.decode("utf-8") + mm_notif, mm_ratio = notif_decode.rsplit("$", 1) + mm_hash, tp_ratio = mm_notif.rsplit(":", 1) + if mm_hash not in self._reqs_to_send: + logger.error( + "Potentially invalid EC cache for unrecognized " + "mm_hash %s was retrieved by a consumer. 
" + "It may have expired.", + mm_hash, + ) + continue + + self.mm_consumer_counts_by_req[mm_hash] += 1 + if self.mm_consumer_counts_by_req[mm_hash] == int(tp_ratio) + int( + mm_ratio + ): + notified_mm_hashes.add(mm_hash) + del self.mm_consumer_counts_by_req[mm_hash] + del self._reqs_to_send[mm_hash] + logger.debug(f"hero: del mm_hash {mm_hash} from self._reqs_to_send") + + return notified_mm_hashes + + def _pop_done_transfers( + self, transfers: dict[str, list[tuple[int, float]]] + ) -> set[str]: + """Pop completed transfers by checking for DONE state.""" + done_mm_hashes: set[str] = set() + for mm_hash, handles in list(transfers.items()): + in_progress = False + for handle, _xfer_stime in handles: + xfer_state = self.nixl_wrapper.check_xfer_state(handle) + if xfer_state == "DONE": + self.nixl_wrapper.release_xfer_handle(handle) + self._release_mm_handle(mm_hash, handle) + elif xfer_state == "PROC": + in_progress = True + continue + else: + raise RuntimeError( + f"EC transfer failed with state {xfer_state}" + ) + if not in_progress: + done_mm_hashes.add(mm_hash) + del transfers[mm_hash] + return done_mm_hashes + + def _release_mm_handle(self, mm_hash: str, handle: int): + """Release NIXL handles and deregister memory for a completed transfer.""" + if (mm_hash, handle) not in self._xfer_side_mm_handle: + return + + _, src_xfer_handle, dst_xfer_handle = self._xfer_side_mm_handle[ + (mm_hash, handle) + ] + _, mm_descs, recv_tensor, _ = self._registered_mm_descs[mm_hash] + + self.nixl_wrapper.release_dlist_handle(src_xfer_handle) + self.nixl_wrapper.release_dlist_handle(dst_xfer_handle) + self.nixl_wrapper.deregister_memory(mm_descs) + + # Copy to encoder cache if needed + if self.encoder_cache is not None: + self._copy_recv_to_encoder_cache(mm_hash) + + del self._xfer_side_mm_handle[(mm_hash, handle)] + del self._registered_mm_descs[mm_hash] + + +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket.""" + + if socket_type not in (zmq.ROUTER, zmq.REQ): + raise ValueError(f"Unexpected socket type: {socket_type}") + + ctx: Optional[zmq.Context] = None + try: + ctx = zmq.Context() # type: ignore[attr-defined] + yield make_zmq_socket( + ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER + ) + finally: + if ctx is not None: + ctx.destroy(linger=0) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 688ea9697d9d..c5e83287893f 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -711,6 +711,10 @@ class ChatCompletionRequest(OpenAIBaseModel): default=None, description="KVTransfer parameters used for disaggregated serving.", ) + ec_transfer_params: dict[str, Any] | None = Field( + default=None, + description="ECTransfer parameters used for disaggregated encoder serving.", + ) vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field( default=None, @@ -812,6 +816,9 @@ def to_sampling_params( if self.kv_transfer_params: # Pass in kv_transfer_params via extra_args extra_args["kv_transfer_params"] = self.kv_transfer_params + if self.ec_transfer_params: + # Pass in ec_transfer_params via extra_args + extra_args["ec_transfer_params"] = self.ec_transfer_params return SamplingParams.from_optional( n=self.n, presence_penalty=self.presence_penalty, @@ -1131,6 +1138,11 @@ class CompletionRequest(OpenAIBaseModel): default=None, description="KVTransfer parameters used for disaggregated serving.", ) + 
ec_transfer_params: dict[str, Any] | None = Field( + default=None, + description="ECTransfer parameters used for disaggregated encoder serving.", + ) + vllm_xargs: dict[str, str | int | float] | None = Field( default=None, @@ -1240,6 +1252,9 @@ def to_sampling_params( if self.kv_transfer_params: # Pass in kv_transfer_params via extra_args extra_args["kv_transfer_params"] = self.kv_transfer_params + if self.ec_transfer_params: + # Pass in ec_transfer_params via extra_args + extra_args["ec_transfer_params"] = self.ec_transfer_params return SamplingParams.from_optional( n=self.n, presence_penalty=self.presence_penalty, @@ -3136,6 +3151,10 @@ class GenerateRequest(BaseModel): default=None, description="KVTransfer parameters used for disaggregated serving.", ) + ec_transfer_params: dict[str, Any] | None = Field( + default=None, + description="ECTransfer parameters used for disaggregated encoder serving.", + ) class GenerateResponseChoice(BaseModel): @@ -3163,3 +3182,7 @@ class GenerateResponse(BaseModel): default=None, description="KVTransfer parameters used for disaggregated serving.", ) + ec_transfer_params: dict[str, Any] | None = Field( + default=None, + description="ECTransfer parameters used for disaggregated encoder serving.", + ) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 9a7051e0920a..7ba3bb137fea 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1582,6 +1582,7 @@ async def chat_completion_full_generator( final_res.prompt_token_ids if request.return_token_ids else None ), kv_transfer_params=final_res.kv_transfer_params, + ec_transfer_params=final_res.ec_transfer_params, ) # Log complete response if output logging is enabled diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 9681aa8c71e6..b61c968de440 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -606,8 +606,11 @@ def request_output_to_completion_response( ) request_metadata.final_usage_info = usage + kv_transfer_params = None + ec_transfer_params = None if final_res_batch: kv_transfer_params = final_res_batch[0].kv_transfer_params + ec_transfer_params = final_res_batch[0].ec_transfer_params return CompletionResponse( id=request_id, created=created_time, @@ -615,6 +618,7 @@ def request_output_to_completion_response( choices=choices, usage=usage, kv_transfer_params=kv_transfer_params, + ec_transfer_params=ec_transfer_params, ) def _create_completion_logprobs( diff --git a/vllm/entrypoints/openai/serving_tokens.py b/vllm/entrypoints/openai/serving_tokens.py index 69a526b9b70d..388f09759db6 100644 --- a/vllm/entrypoints/openai/serving_tokens.py +++ b/vllm/entrypoints/openai/serving_tokens.py @@ -210,6 +210,7 @@ async def serve_tokens_full_generator( usage=usage, prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), kv_transfer_params=final_res.kv_transfer_params, + ec_transfer_params=final_res.ec_transfer_params, ) # Log complete response if output logging is enabled diff --git a/vllm/envs.py b/vllm/envs.py index 56558548d398..69d60c86d9ff 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -171,6 +171,7 @@ VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5600 + VLLM_NIXL_EC_SIDE_CHANNEL_PORT: int = 5700 VLLM_ALL2ALL_BACKEND: Literal[ "naive", "pplx", @@ -1234,6 +1235,9 @@ def get_vllm_port() -> int | None: 
"VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int( os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5600") ), + "VLLM_NIXL_EC_SIDE_CHANNEL_PORT": lambda: int( + os.getenv("VLLM_NIXL_EC_SIDE_CHANNEL_PORT", "5700") + ), # all2all backend for vllm's expert parallel communication # Available options: # - "naive": naive all2all implementation using broadcasts diff --git a/vllm/outputs.py b/vllm/outputs.py index cdfe06f1c7fa..2cdfab8cc99f 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -103,6 +103,7 @@ class RequestOutput: None if decoder-only. num_cached_tokens: The number of tokens with prefix cache hit. kv_transfer_params: The params for remote K/V transfer. + ec_transfer_params: The params for remote encoder cache transfer. """ def __init__( @@ -121,6 +122,7 @@ def __init__( *, multi_modal_placeholders: MultiModalPlaceholderDict | None = None, kv_transfer_params: dict[str, Any] | None = None, + ec_transfer_params: dict[str, Any] | None = None, # Forward compatibility, code that uses args added in new release can # still run with older versions of vLLM without breaking. **kwargs: Any, @@ -142,12 +144,14 @@ def __init__( self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens self.kv_transfer_params = kv_transfer_params + self.ec_transfer_params = ec_transfer_params def add(self, next_output: "RequestOutput", aggregate: bool) -> None: """Merge subsequent RequestOutput into this one""" self.finished |= next_output.finished self.kv_transfer_params = next_output.kv_transfer_params + self.ec_transfer_params = next_output.ec_transfer_params for next_completion in next_output.outputs: for i, completion in enumerate(self.outputs): diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 0304a8ec48bf..800585aa1c54 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -992,6 +992,7 @@ def update_from_output( pooler_outputs = model_runner_output.pooler_output num_nans_in_logits = model_runner_output.num_nans_in_logits kv_connector_output = model_runner_output.kv_connector_output + ec_connector_output = model_runner_output.ec_connector_output outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: SpecDecodingStats | None = None @@ -1012,6 +1013,11 @@ def update_from_output( kv_connector_output.invalid_block_ids ) + logger.debug(f"hero: self.ec_connector: {self.ec_connector} / ec_connector_output: {ec_connector_output}") + if self.ec_connector.is_producer and ec_connector_output: + logger.debug(f"hero: update_mm_hash_addrs_from_output!") + self.ec_connector.update_mm_hash_addrs_from_output(ec_connector_output) + # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, # the below loop can be a performance bottleneck. We should do our best # to avoid expensive operations inside the loop. @@ -1077,11 +1083,13 @@ def update_from_output( stopped = check_stop(request, self.max_model_len, pooler_output) if stopped: - kv_transfer_params = self._free_request(request) + kv_transfer_params, ec_transfer_params = self._free_request(request) if status_before_stop == RequestStatus.RUNNING: stopped_running_reqs.add(request) else: stopped_preempted_reqs.add(request) + else: + ec_transfer_params = None # Extract sample logprobs if needed. if ( @@ -1102,7 +1110,7 @@ def update_from_output( # Get prompt logprobs for this request. 
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
 
-            if new_token_ids or pooler_output is not None or kv_transfer_params:
+            if new_token_ids or pooler_output is not None or kv_transfer_params or ec_transfer_params:
                 # Add EngineCoreOutput for this Request.
                 outputs[request.client_index].append(
                     EngineCoreOutput(
@@ -1115,6 +1123,7 @@ def update_from_output(
                         stop_reason=request.stop_reason,
                         events=request.take_events(),
                         kv_transfer_params=kv_transfer_params,
+                        ec_transfer_params=ec_transfer_params,
                         trace_headers=request.trace_headers,
                         num_cached_tokens=request.num_cached_tokens,
                         num_nans_in_logits=request.num_nans_in_logits,
@@ -1305,10 +1314,14 @@ def finish_requests(
                 request.status = finished_status
                 self._free_request(request)
 
-    def _free_request(self, request: Request) -> dict[str, Any] | None:
+    def _free_request(self, request: Request) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
         assert request.is_finished()
 
         delay_free_blocks, kv_xfer_params = self._connector_finished(request)
+
+        # Handle EC connector request_finished
+        _, ec_xfer_params = self._ec_connector_finished(request)
+
         self.encoder_cache_manager.free(request)
         request_id = request.request_id
         self.finished_req_ids.add(request_id)
@@ -1318,7 +1331,7 @@ def _free_request(self, request: Request) -> dict[str, Any] | None:
         if not delay_free_blocks:
             self._free_blocks(request)
 
-        return kv_xfer_params
+        return kv_xfer_params, ec_xfer_params
 
     def _free_blocks(self, request: Request):
         assert request.is_finished()
@@ -1423,6 +1436,20 @@ def _connector_finished(
 
         return self.connector.request_finished_all_groups(request, block_ids)
 
+    def _ec_connector_finished(
+        self, request: Request
+    ) -> tuple[bool, dict[str, Any] | None]:
+        """
+        Invoke the EC connector request_finished() method if applicable.
+
+        Returns optional ec transfer parameters to be included with the
+        request outputs.
+        """
+        if self.ec_connector is None:
+            return False, None
+
+        return self.ec_connector.request_finished(request)
+
     def _update_waiting_for_remote_kv(self, request: Request) -> bool:
         """
         KV Connector: check if the request_id is finished_recving.
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index ce2aae77108d..76f7ba47c865 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -126,6 +126,7 @@ class EngineCoreOutput(
     stop_reason: int | str | None = None
     events: list[EngineCoreEvent] | None = None
     kv_transfer_params: dict[str, Any] | None = None
+    ec_transfer_params: dict[str, Any] | None = None
     trace_headers: Mapping[str, str] | None = None
 
     # The number of tokens with prefix cache hits.
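Reviewer note (not part of the patch): the hunks above plumb ec_transfer_params end to end, mirroring kv_transfer_params — the OpenAI frontend copies it into SamplingParams.extra_args, the scheduler returns it from _free_request(), and it is attached to EngineCoreOutput / RequestOutput. A minimal client-side sketch of that round trip is below; the URL, port, model name, and payload values are placeholders, and whether the response echoes the field depends on the EC connector setting it on request finish.

# Illustrative only; field name "ec_transfer_params" comes from this patch,
# everything else is an assumption.
import requests

payload = {
    "model": "my-multimodal-model",  # placeholder model name
    "messages": [{"role": "user", "content": "Describe the image."}],
    "max_tokens": 32,
    # Accepted by ChatCompletionRequest and forwarded through
    # SamplingParams.extra_args["ec_transfer_params"].
    "ec_transfer_params": {"do_remote_encode": True},
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
# Surfaced on the final response alongside kv_transfer_params when the
# EC connector attaches transfer params for the finished request.
print(resp.json().get("ec_transfer_params"))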
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 0453c4a77f0c..8a17f8522167 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -204,6 +204,7 @@ def make_request_output( finish_reason: FinishReason | None, stop_reason: int | str | None, kv_transfer_params: dict[str, Any] | None = None, + ec_transfer_params: dict[str, Any] | None = None, ) -> RequestOutput | PoolingRequestOutput | None: finished = finish_reason is not None final_only = self.output_kind == RequestOutputKind.FINAL_ONLY @@ -253,7 +254,7 @@ def make_request_output( return None return self._new_request_output( - request_id, outputs, finished, kv_transfer_params + request_id, outputs, finished, kv_transfer_params, ec_transfer_params ) def _new_request_output( @@ -262,6 +263,7 @@ def _new_request_output( outputs: list[CompletionOutput] | list[PoolingOutput], finished: bool, kv_transfer_params: dict[str, Any] | None = None, + ec_transfer_params: dict[str, Any] | None = None, ) -> RequestOutput | PoolingRequestOutput: first_output = outputs[0] if isinstance(first_output, PoolingOutput): @@ -295,6 +297,7 @@ def _new_request_output( outputs=cast(list[CompletionOutput], outputs), finished=finished, kv_transfer_params=kv_transfer_params, + ec_transfer_params=ec_transfer_params, num_cached_tokens=self.num_cached_tokens, metrics=self.stats, ) @@ -482,6 +485,7 @@ def process_outputs( finish_reason = engine_core_output.finish_reason stop_reason = engine_core_output.stop_reason kv_transfer_params = engine_core_output.kv_transfer_params + ec_transfer_params = engine_core_output.ec_transfer_params req_state.num_cached_tokens = engine_core_output.num_cached_tokens req_state.is_prefilling = False @@ -507,6 +511,7 @@ def process_outputs( finish_reason, stop_reason, kv_transfer_params, + ec_transfer_params, ): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index e32d5bb608b1..17eb4592ca8a 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -143,6 +143,7 @@ class ECConnectorOutput: # [mm_hash] finished_sending: set[str] | None = None finished_recving: set[str] | None = None + mm_hash_addrs: dict[str, dict[str, int]] | None = None # ModelRunnerOutput is serialized and sent to the scheduler process. diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 366cdadf5a58..175f566b7aae 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -64,6 +64,15 @@ def __init__( # P/D: Connector-specific KV transfer parameters. self.kv_transfer_params: dict[str, Any] | None = None + # E/P: Connector-specific EC transfer parameters. + # Format: dict[mm_hash, dict] where each dict contains: + # - do_remote_encode: bool + # - remote_engine_id: str + # - remote_host: str + # - remote_port: int + # - remote_mm_segments: dict[mm_hash, list[tuple[int, int]]] + # - tp_size: int + self.ec_transfer_params: dict[str, Any] | None = None if pooling_params is not None: # Pooling models. 
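To make the format comment above concrete, a hypothetical Request.ec_transfer_params value following that layout is sketched below; the hash, addresses, and segment ranges are invented for illustration only.

# Hypothetical example of the dict[mm_hash, dict] layout described above.
ec_transfer_params = {
    "mm_hash_0123abcd": {
        "do_remote_encode": True,
        "remote_engine_id": "encoder-engine-0",
        "remote_host": "127.0.0.1",
        "remote_port": 5569,
        # per-mm_hash list of (offset, length) token segments
        "remote_mm_segments": {"mm_hash_0123abcd": [(0, 1024)]},
        "tp_size": 1,
    }
}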
@@ -79,6 +88,9 @@ def __init__( self.kv_transfer_params = sampling_params.extra_args.get( "kv_transfer_params" ) + self.ec_transfer_params = sampling_params.extra_args.get( + "ec_transfer_params" + ) else: raise ValueError("sampling_params and pooling_params can't both be unset") diff --git a/vllm/v1/worker/ec_connector_model_runner_mixin.py b/vllm/v1/worker/ec_connector_model_runner_mixin.py index 00bc909df297..6437cb8fd7a2 100644 --- a/vllm/v1/worker/ec_connector_model_runner_mixin.py +++ b/vllm/v1/worker/ec_connector_model_runner_mixin.py @@ -83,5 +83,11 @@ def _get_ec_connector_output( output.finished_sending, output.finished_recving = ( ec_connector.get_finished(scheduler_output.finished_req_ids) ) + logger.debug(f"hero: finally@!") + logger.debug(f"hero: output.finished_sending, output.finished_recving: {output.finished_sending, output.finished_recving}") + if ec_connector.is_producer: + logger.debug(f"hero: get_mm_hash_addrs:") + output.mm_hash_addrs = ec_connector.get_mm_hash_addrs() + logger.debug(f"hero: output.mm_hash_addrs: {output.mm_hash_addrs}") ec_connector.clear_connector_metadata() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0ae4eb48acf2..e92da4e8bb9f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -258,6 +258,25 @@ class ExecuteModelState(NamedTuple): aux_hidden_states: list[torch.Tensor] | None ec_connector_output: ECConnectorOutput | None +class CustomTensorPool: + def __init__(self, max_num_tokens, hidden_size, dtype, device): + self.pool = torch.empty((114688, hidden_size), dtype=dtype, device=device) + # self.total_size = max_num_tokens + self.total_size = 114688 + self.offset = 0 + self.allocated_blocks = {} # Track allocations by mm_hash + + def allocate(self, size, mm_hash=None): + if self.offset + size > self.total_size: + # raise RuntimeError("Pool exhausted") + logger.debug(f"hero: Pool exhausted; reset offset to 0") + self.offset = 0 + tensor = self.pool[self.offset:self.offset + size] + if mm_hash: + self.allocated_blocks[mm_hash] = (self.offset, size) + self.offset += size + return tensor + class GPUModelRunner( LoRAModelRunnerMixin, KVConnectorModelRunnerMixin, ECConnectorModelRunnerMixin @@ -366,6 +385,69 @@ def __init__( # mm_hash -> encoder_output self.encoder_cache: dict[str, torch.Tensor] = {} + ## nixl cache! + # ================================================================== + self.ec_main_cache: torch.Tensor | None = None + if has_ec_transfer() and get_ec_transfer().is_producer: + # For encoder, self.max_num_tokens => max number of encoder input + if self.max_num_tokens > 0: + self.ec_main_cache = CustomTensorPool( + max_num_tokens=self.max_num_tokens, + hidden_size=self.hidden_size, + dtype=self.dtype, + device=self.device, + ) + logger.info( + "Registering EC main cache tensor for Nixl EC producer: " + "max_num_tokens=%s, shape=%s, dtype=%s", + self.max_num_tokens, + self.ec_main_cache.pool.shape, + self.ec_main_cache.pool.dtype, + ) + get_ec_transfer().register_encoder_cache(self.ec_main_cache.pool) + else: + logger.warning( + "EC transfer producer enabled but max_encoder_len == 0; " + "skipping EC main cache registration." + ) + + # ================================================================== + # Nixl ECConnector: register main encoder cache pool on producers. + # This provides a stable base address and per-token stride that the + # Nixl ECConnector uses for metadata in its handshake. 
+ # # ================================================================== + # self.ec_main_cache: torch.Tensor | None = None + # if has_ec_transfer() and get_ec_transfer().is_producer: + # # For encoder, self.max_num_tokens => max number of encoder input + # if self.max_num_tokens > 0: + # self.ec_main_cache = torch.empty( + # (self.max_num_tokens, self.hidden_size), + # device=self.device, + # dtype=self.dtype, + # ) + # logger.info( + # "Registering EC main cache tensor for Nixl EC producer: " + # "shape=%s, dtype=%s", + # self.ec_main_cache.shape, + # self.ec_main_cache.dtype, + # ) + # get_ec_transfer().register_encoder_cache(self.ec_main_cache) + + # # # hero; why not just use self.encoder_cache + # # logger.info( + # # "Registering self.encoder_cache tensor for Nixl EC producer: " + # # "shape=%s, dtype=%s", + # # self.encoder_cache.shape, + # # self.encoder_cache.dtype, + # # ) + # # get_ec_transfer().register_encoder_cache(self.self.encoder_cache) + # else: + # logger.warning( + # "EC transfer producer enabled but max_encoder_len == 0; " + # "skipping EC main cache registration." + # ) + + self.use_aux_hidden_state_outputs = False # Set up speculative decoding. # NOTE(Jiayi): currently we put the entire draft model on @@ -2105,11 +2187,30 @@ def _execute_mm_encoder( # Cache the encoder outputs by mm_hash for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): + # Calculate size needed (in elements, not bytes) + tensor_size = pos_info.length + logger.debug(f"hero: tensor_size: {tensor_size}") + + # Allocate from pool - this creates a view, no copy! + pooled_tensor = self.ec_main_cache.allocate( + tensor_size, + mm_hash=mm_hash + ).view(output.shape) + + # Copy encoder output to pooled location + pooled_tensor.copy_(output) + + # self.encoder_cache[mm_hash] = scatter_mm_placeholders( + # output, + # is_embed=pos_info.is_embed, + # ) + self.encoder_cache[mm_hash] = scatter_mm_placeholders( - output, + pooled_tensor, is_embed=pos_info.is_embed, ) logger.debug("Finish execute for mm hash %s", mm_hash) + logger.debug(f"hero: size: {self.encoder_cache[mm_hash].size()} / self.encoder_cache[mm_hash] for {mm_hash}: {self.encoder_cache[mm_hash]}") self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) return encoder_outputs @@ -2141,6 +2242,9 @@ def _gather_mm_embeddings( start_pos = pos_info.offset num_encoder_tokens = pos_info.length + # hero: + logger.debug(f"hero: start_pos: {start_pos}; num_computed_tokens: {num_computed_tokens}; num_scheduled_tokens: {num_scheduled_tokens}; num_encoder_tokens: {num_encoder_tokens}") + # The encoder output is needed if the two ranges overlap: # [num_computed_tokens, # num_computed_tokens + num_scheduled_tokens) and @@ -2153,6 +2257,10 @@ def _gather_mm_embeddings( # in the decoder's KV cache. 
continue + # logger.debug(f"hero: sleep 3s to try wait for the nixl thing?") + # time.sleep(3) + # logger.debug(f"hero: finish sleeping 3s") + start_idx = max(num_computed_tokens - start_pos, 0) end_idx = min( num_computed_tokens - start_pos + num_scheduled_tokens, @@ -2813,7 +2921,11 @@ def execute_model( encoder_cache=self.encoder_cache, ) as ec_connector_output: self._execute_mm_encoder(scheduler_output) - return make_empty_encoder_model_runner_output(scheduler_output) + # return make_empty_encoder_model_runner_output(scheduler_output) + encoder_model_runner_output = make_empty_encoder_model_runner_output(scheduler_output) + encoder_model_runner_output.ec_connector_output = ec_connector_output + logger.debug(f"hero: modified encoder_model_runner_output: {encoder_model_runner_output}") + return encoder_model_runner_output if not num_scheduled_tokens: if ( @@ -2922,6 +3034,8 @@ def execute_model( scheduler_output, num_tokens_padded, intermediate_tensors ) + logger.debug(f"hero: ec_connector_output from _preprocess: {ec_connector_output}") + # Set cudagraph mode to none if calc_kv_scales is true. # KV scales calculation involves dynamic operations that are incompatible # with CUDA graph capture. From e31b4b5b065516d0c61f2f7701a2a50cbc592f8d Mon Sep 17 00:00:00 2001 From: herotai214 Date: Wed, 17 Dec 2025 08:15:51 +0000 Subject: [PATCH 2/2] [WIP][Feature] Nixl ECConnector works properly Signed-off-by: herotai214 --- .../ec_transfer/ec_connector/base.py | 5 +- .../ec_connector/nixl_connector.py | 99 +++++++++++-------- .../ec_connector/shared_storage_connector.py | 15 +++ vllm/v1/core/encoder_cache_manager.py | 3 + vllm/v1/worker/gpu_model_runner.py | 28 +++++- 5 files changed, 103 insertions(+), 47 deletions(-) diff --git a/vllm/distributed/ec_transfer/ec_connector/base.py b/vllm/distributed/ec_transfer/ec_connector/base.py index 2b7b14d89b8a..106276ff561a 100644 --- a/vllm/distributed/ec_transfer/ec_connector/base.py +++ b/vllm/distributed/ec_transfer/ec_connector/base.py @@ -113,15 +113,16 @@ def _get_connector_metadata(self) -> ECConnectorMetadata: def register_caches( self, - ec_caches: dict[str, torch.Tensor], + ec_main_cache, ): """ Initialize with the EC caches. 
Args: - ec_caches: dictionary of encoder cache + ec_main_cache """ # TODO: Implement this later for P2P feature return + @abstractmethod def start_load_caches( diff --git a/vllm/distributed/ec_transfer/ec_connector/nixl_connector.py b/vllm/distributed/ec_transfer/ec_connector/nixl_connector.py index 441666634534..b491c26404f3 100644 --- a/vllm/distributed/ec_transfer/ec_connector/nixl_connector.py +++ b/vllm/distributed/ec_transfer/ec_connector/nixl_connector.py @@ -140,7 +140,7 @@ def __init__(self, vllm_config: VllmConfig, role: ECConnectorRole): def register_encoder_cache( self, - ec_cache: torch.Tensor, + ec_cache, ): """Register encoder cache tensors with NIXL.""" assert self.connector_worker is not None @@ -483,6 +483,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.encoder_cache: Optional[torch.Tensor] = None self.enc_base_addr = 0 self.enc_token_bytes = 0 + self.hidden_size = vllm_config.model_config.get_hidden_size() + logger.debug(f"hero: self.hidden_size: {self.hidden_size}") self._registered_descs: list[Any] = [] # TODO: find a more elegant way to store & manage mm_base_addr self._ENCODER_MM_BASE_ADDRS: dict[EngineId, dict[MMHash, int]] = {} @@ -531,14 +533,15 @@ def __del__(self): if self._nixl_handshake_listener_t: self._nixl_handshake_listener_t.join(timeout=0) - def register_encoder_cache(self, encoder_cache: torch.Tensor): + def register_encoder_cache(self, encoder_cache): """Register the main encoder cache tensor with NIXL.""" - self.encoder_cache = encoder_cache - self.enc_base_addr = encoder_cache.data_ptr() + self.ec_main_cache = encoder_cache + self.encoder_cache = encoder_cache.pool + self.enc_base_addr = encoder_cache.pool.data_ptr() self.enc_token_bytes = ( - encoder_cache[0].numel() * encoder_cache.element_size() + encoder_cache.pool[0].numel() * encoder_cache.pool.element_size() ) - enc_size_bytes = encoder_cache.numel() * encoder_cache.element_size() + enc_size_bytes = encoder_cache.pool.numel() * encoder_cache.pool.element_size() caches_data = [(self.enc_base_addr, enc_size_bytes, 0, "")] descs = self.nixl_wrapper.get_reg_descs( @@ -775,17 +778,19 @@ def register_encoder_recv_tensor( ): """Register a receive tensor for encoder cache transfer.""" base_addr = recv_tensor.data_ptr() - size_bytes = recv_tensor.numel() * recv_tensor.element_size() + # size_bytes = recv_tensor.numel() * recv_tensor.element_size() - self.device_id = max(recv_tensor.get_device(), 0) - caches_data = [(base_addr, size_bytes, self.device_id, "")] + # self.device_id = max(recv_tensor.get_device(), 0) + # caches_data = [(base_addr, size_bytes, self.device_id, "")] - descs = self.nixl_wrapper.get_reg_descs( - caches_data, self.nixl_memory_type - ) - logger.debug("Registering descs: %s", caches_data) - self.nixl_wrapper.register_memory(descs) - logger.debug("Done registering descs") + # descs = self.nixl_wrapper.get_reg_descs( + # caches_data, self.nixl_memory_type + # ) + # logger.debug("Registering descs: %s", caches_data) + # self.nixl_wrapper.register_memory(descs) + # logger.debug("Done registering descs") + + descs = "dummy" # hero self._registered_mm_descs[mm_hash] = ( base_addr, descs, @@ -816,7 +821,7 @@ def start_load_caches( logger.debug(f"hero: start_load_caches: {metadata.reqs_to_recv.items()}") # Reference the encoder_cache - self._encoder_cache_dict = encoder_cache + # self._encoder_cache_dict = encoder_cache # hero # # First, register all receive tensors from encoder_cache @@ -864,7 +869,7 @@ def start_load_caches( # time.sleep(2) # logger.debug(f"hero: 
wait 2 for request ready before read mm") # Handshake completed, start async read transfer - self._read_mm_segments(mm_hash, meta) + self._read_mm_segments(mm_hash, meta, encoder_cache) if metadata.reqs_to_recv: # if not empty for mm_hash, meta in metadata.reqs_to_recv.items(): @@ -884,12 +889,14 @@ def start_load_caches( while not self._ready_requests.empty(): logger.debug(f"while not self._ready_requests.empty():") - self._read_mm_segments(*self._ready_requests.get_nowait()) + self._read_mm_segments(*self._ready_requests.get_nowait(), encoder_cache) # Add to requests waiting to be read self._reqs_to_send.update(metadata.reqs_to_send) - def _read_mm_segments(self, mm_hash: str, meta: ECReqMeta): + self._encoder_cache_dict = encoder_cache + + def _read_mm_segments(self, mm_hash: str, meta: ECReqMeta, encoder_cache): """Read encoder cache from remote via NIXL. Transfers the entire encoder cache tensor for the given mm_hash. @@ -909,7 +916,7 @@ def _read_mm_segments(self, mm_hash: str, meta: ECReqMeta): self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) return - if mm_hash not in self._registered_mm_descs: + if mm_hash not in self._registered_mm_descs or mm_hash not in encoder_cache: if remote_engine_id not in self._remote_enc_base_addr: logger.error( "Remote encoder base addr for engine %s not found when " @@ -927,27 +934,37 @@ def _read_mm_segments(self, mm_hash: str, meta: ECReqMeta): base_addr, token_bytes = self._remote_enc_base_addr[remote_engine_id] # Derive hidden size from bytes-per-token and dtype element size. - elem_size = torch.tensor([], dtype=self.encoder_cache_dtype).element_size() - assert token_bytes % elem_size == 0, ( - f"enc_token_bytes {token_bytes} not divisible by element size " - f"{elem_size} for dtype {self.encoder_cache_dtype}" - ) - hidden_size = token_bytes // elem_size + # elem_size = torch.tensor([], dtype=self.encoder_cache_dtype).element_size() + # assert token_bytes % elem_size == 0, ( + # f"enc_token_bytes {token_bytes} not divisible by element size " + # f"{elem_size} for dtype {self.encoder_cache_dtype}" + # ) + # hidden_size = token_bytes // elem_size + hidden_size = self.hidden_size + + logger.info(f"hero: before recv_tensor empty mm_hash {mm_hash}") # Allocate local receive tensor and expose it to the encoder_cache dict. 
- recv_tensor = torch.empty( - (num_encoder_tokens, hidden_size), - device=self.device_type, - dtype=self.encoder_cache_dtype, - ) + # recv_tensor = torch.empty( + # (num_encoder_tokens, hidden_size), + # device=self.device_type, + # dtype=self.encoder_cache_dtype, + # ) + + recv_tensor = self.ec_main_cache.allocate(num_encoder_tokens, mm_hash=mm_hash) - assert self._encoder_cache_dict is not None - logger.debug(f"hero: self._encoder_cache_dict: {self._encoder_cache_dict}") - self._encoder_cache_dict[mm_hash] = recv_tensor + logger.info(f"hero: after recv_tensor empty mm_hash {mm_hash}") - logger.debug(f"hero: self._encoder_cache_dict after recv_tensor: {self._encoder_cache_dict}") + # assert self._encoder_cache_dict is not None + assert encoder_cache is not None + # logger.debug(f"hero: self._encoder_cache_dict: {self._encoder_cache_dict}") + # self._encoder_cache_dict[mm_hash] = recv_tensor + encoder_cache[mm_hash] = recv_tensor # hero + logger.debug(f"hero: encoder_cache.keys() after recv_tensor: len:{len(encoder_cache.keys()), encoder_cache.keys()}") + + # logger.debug(f"hero: self._encoder_cache_dict after recv_tensor: {self._encoder_cache_dict}") - logger.debug(f"hero: size: {recv_tensor.size(), self._encoder_cache_dict[mm_hash].size()} / recv_tensor for {mm_hash}: {self._encoder_cache_dict[mm_hash]}") + # logger.debug(f"hero: size: {recv_tensor.size(), self._encoder_cache_dict[mm_hash].size()} / recv_tensor for {mm_hash}: {self._encoder_cache_dict[mm_hash]}") logger.debug( "Allocating receive tensor for mm_hash %s with shape %s " @@ -1083,12 +1100,13 @@ def _copy_recv_to_encoder_cache(self, mm_hash: str): # hero: if self._encoder_cache_dict is not None: self._encoder_cache_dict[mm_hash] = recv_tensor.clone() + logger.debug(f"hero: self._encoder_cache_dict.keys(): len:{len(self._encoder_cache_dict.keys()), self._encoder_cache_dict.keys()}") logger.debug( "Copied received encoder cache for mm_hash %s, shape: %s", mm_hash, recv_tensor.shape, ) - logger.debug(f"hero: mm_hash {mm_hash} cloned tensor: {recv_tensor}") + # logger.debug(f"hero: mm_hash {mm_hash} cloned tensor: {recv_tensor}") logger.debug( "Encoder cache transfer completed for mm_hash %s, shape: %s", @@ -1134,8 +1152,9 @@ def _pop_done_transfers( for handle, _xfer_stime in handles: xfer_state = self.nixl_wrapper.check_xfer_state(handle) if xfer_state == "DONE": - self.nixl_wrapper.release_xfer_handle(handle) - self._release_mm_handle(mm_hash, handle) + # self.nixl_wrapper.release_xfer_handle(handle) + # self._release_mm_handle(mm_hash, handle) + xfer_state == "DONE" # hero elif xfer_state == "PROC": in_progress = True continue @@ -1152,7 +1171,7 @@ def _release_mm_handle(self, mm_hash: str, handle: int): """Release NIXL handles and deregister memory for a completed transfer.""" if (mm_hash, handle) not in self._xfer_side_mm_handle: return - + logger.debug(f"hero: _release_mm_handle") _, src_xfer_handle, dst_xfer_handle = self._xfer_side_mm_handle[ (mm_hash, handle) ] diff --git a/vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py b/vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py index c8388141dcc9..8e8eaabfc30c 100644 --- a/vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py +++ b/vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py @@ -60,6 +60,21 @@ def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole): else: raise ValueError("ec_transfer_config must be set for ECConnectorBase") + def register_encoder_cache( + self, + ec_main_cache, 
+ ): + """ + Initialize with the EC caches. + Args: + ec_main_cache + """ + # TODO: Implement this later for P2P feature + return + + def get_mm_hash_addrs(self): + return + def start_load_caches(self, encoder_cache, **kwargs) -> None: """ Start loading the cache from the connector into vLLM's encoder cache. diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 3959e9a59a53..a2bd0186433e 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -94,6 +94,7 @@ def check_and_update_cache(self, request: Request, input_id: int) -> bool: # Cached but currently not referenced by any request if not self.cached[mm_hash]: num_tokens = self.freeable.pop(mm_hash) + logger.debug(f"hero: self.freeable.pop mm_hash self.freeable") self.num_freeable_slots -= num_tokens self.cached[mm_hash].add(request.request_id) @@ -156,6 +157,7 @@ def can_allocate( while num_tokens > self.num_free_slots: mm_hash, num_free_token = self.freeable.popitem(last=False) del self.cached[mm_hash] + logger.debug(f"hero: physically deleted cache for mm_hash {mm_hash}") self.freed.append(mm_hash) self.num_free_slots += num_free_token return True @@ -220,6 +222,7 @@ def free_encoder_input(self, request: Request, input_id: int) -> None: if not self.cached[mm_hash]: num_tokens = request.get_num_encoder_tokens(input_id) self.freeable[mm_hash] = num_tokens + logger.debug(f"self.freeable mm_hash {mm_hash}") self.num_freeable_slots += num_tokens def free(self, request: Request) -> None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e92da4e8bb9f..3512cbb43b4f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -273,7 +273,7 @@ def allocate(self, size, mm_hash=None): self.offset = 0 tensor = self.pool[self.offset:self.offset + size] if mm_hash: - self.allocated_blocks[mm_hash] = (self.offset, size) + self.allocated_blocks[mm_hash] = {"offset": self.offset, "size": size} self.offset += size return tensor @@ -388,7 +388,8 @@ def __init__( ## nixl cache! 
# ================================================================== self.ec_main_cache: torch.Tensor | None = None - if has_ec_transfer() and get_ec_transfer().is_producer: + # if has_ec_transfer() and get_ec_transfer().is_producer: + if has_ec_transfer(): # For encoder, self.max_num_tokens => max number of encoder input if self.max_num_tokens > 0: self.ec_main_cache = CustomTensorPool( @@ -404,7 +405,7 @@ def __init__( self.ec_main_cache.pool.shape, self.ec_main_cache.pool.dtype, ) - get_ec_transfer().register_encoder_cache(self.ec_main_cache.pool) + get_ec_transfer().register_encoder_cache(self.ec_main_cache) else: logger.warning( "EC transfer producer enabled but max_encoder_len == 0; " @@ -2195,7 +2196,23 @@ def _execute_mm_encoder( pooled_tensor = self.ec_main_cache.allocate( tensor_size, mm_hash=mm_hash - ).view(output.shape) + ).view(output.shape) + + # evict cache when pool run out of space + logger.info(f"hero: self.ec_main_cache.allocated_blocks.items(): {self.ec_main_cache.allocated_blocks.items()}") + keys = list(self.ec_main_cache.allocated_blocks.keys()) + for mm_hash_allocated in keys: + pooled_tensor_start = self.ec_main_cache.offset + pooled_tensor_end = self.ec_main_cache.offset + tensor_size + cache_start = self.ec_main_cache.allocated_blocks[mm_hash_allocated]["offset"] + cache_end = self.ec_main_cache.allocated_blocks[mm_hash_allocated]["offset"] + self.ec_main_cache.allocated_blocks[mm_hash_allocated]["size"] + + if (pooled_tensor_start < cache_start and pooled_tensor_end < cache_start) or (pooled_tensor_start > cache_end and pooled_tensor_end > cache_end): + continue + else: + logger.info(f"hero: from encoder_cache pop mm_hash {mm_hash_allocated}; {pooled_tensor_start, pooled_tensor_end, cache_start, cache_end}") + self.encoder_cache.pop(mm_hash_allocated, None) + self.ec_main_cache.allocated_blocks.pop(mm_hash_allocated, None) # Copy encoder output to pooled location pooled_tensor.copy_(output) @@ -2210,7 +2227,7 @@ def _execute_mm_encoder( is_embed=pos_info.is_embed, ) logger.debug("Finish execute for mm hash %s", mm_hash) - logger.debug(f"hero: size: {self.encoder_cache[mm_hash].size()} / self.encoder_cache[mm_hash] for {mm_hash}: {self.encoder_cache[mm_hash]}") + # logger.debug(f"hero: size: {self.encoder_cache[mm_hash].size()} / self.encoder_cache[mm_hash] for {mm_hash}: {self.encoder_cache[mm_hash]}") self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) return encoder_outputs @@ -2270,6 +2287,7 @@ def _gather_mm_embeddings( mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) + logger.debug(f"hero: self.encoder_cache key: {self.encoder_cache.keys()}") assert encoder_output is not None, f"Encoder cache miss for {mm_hash}." if (is_embed := pos_info.is_embed) is not None: