diff --git a/examples/online_serving/disaggregated_encoder/mooncake_connector/disagg_1e1p1d_example.sh b/examples/online_serving/disaggregated_encoder/mooncake_connector/disagg_1e1p1d_example.sh new file mode 100755 index 000000000000..06890c4939a5 --- /dev/null +++ b/examples/online_serving/disaggregated_encoder/mooncake_connector/disagg_1e1p1d_example.sh @@ -0,0 +1,240 @@ +#!/bin/bash +set -euo pipefail + +declare -a PIDS=() +MOONCAKE_MASTER_PID="" + +############################################################################### +# Configuration -- override via env before running +############################################################################### +MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}" +LOG_PATH="${LOG_PATH:-./logs}" +mkdir -p $LOG_PATH + +ENCODE_PORT="${ENCODE_PORT:-19534}" +PREFILL_PORT="${PREFILL_PORT:-19535}" +DECODE_PORT="${DECODE_PORT:-19536}" +PROXY_PORT="${PROXY_PORT:-10001}" + +GPU_E="${GPU_E:-0}" +GPU_P="${GPU_P:-1}" +GPU_D="${GPU_D:-2}" + +TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout +NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark + +MOONCAKE_MASTER_PORT=50051 +MOONCAKE_METADATA_PORT=8080 +MOONCAKE_MASTER_IP="localhost" # producer +MOONCAKE_STORE_INSTANCE_IP="localhost" # consumer +MOONCAKE_GLOBAL_SEGMENT_SIZE=$((30 * 1073741824)) # 30 GB +MOONCAKE_LOCAL_BUFFER_SIZE=$((1 * 1073741824)) # 1 GB +MOONCAKE_REPLICA_NUM=1 +MOONCAKE_FAST_TRANSFER=true +MOONCAKE_FAST_TRANSFER_BUFFER_SIZE=$((30 * 1073741824)) # 30 GB + +export UCX_TLS=all +export UCX_NET_DEVICES=all + +############################################################################### +# Helpers +############################################################################### +START_TIME=$(date +"%Y%m%d_%H%M%S") +MOONCAKE_MASTER_LOG="$LOG_PATH/mooncake_master_$START_TIME.log" +ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log +P_LOG=$LOG_PATH/p_${START_TIME}.log +D_LOG=$LOG_PATH/d_${START_TIME}.log +PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log + +wait_for_server() { + local port=$1 + timeout "$TIMEOUT_SECONDS" bash -c " + until curl -s localhost:$port/v1/chat/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Cleanup function +cleanup() { + echo "Stopping everything…" + trap - INT TERM USR1 # prevent re-entrancy + + # Kill all tracked PIDs + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Killing process $pid" + kill "$pid" 2>/dev/null + fi + done + + # Wait a moment for graceful shutdown + wait "${PIDS[@]}" 2>/dev/null || true + + # Force kill any remaining processes + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Force killing process $pid" + kill -9 "$pid" 2>/dev/null + fi + done + + echo "Force killing mooncake processes" + pkill -f "mooncake_master" + + echo "All processes stopped." 
+ exit 0 +} + +trap cleanup INT +trap cleanup USR1 +trap cleanup TERM + +############################################################################### +# Initialize Mooncake +# Read more about Mooncake config at +# https://kvcache-ai.github.io/Mooncake/deployment/mooncake-store-deployment-guide.html +############################################################################### +mooncake_master \ + --rpc_port $MOONCAKE_MASTER_PORT \ + --enable_http_metadata_server=true \ + --http_metadata_server_host=0.0.0.0 \ + --http_metadata_server_port=$MOONCAKE_METADATA_PORT \ + --rpc_thread_num 8 \ + --default_kv_lease_ttl 5000 \ + --eviction_ratio 0.05 \ + --eviction_high_watermark_ratio 0.9 \ + >"$MOONCAKE_MASTER_LOG" 2>&1 & + +export MC_MS_AUTO_DISC=0 + +############################################################################### +# Encoder worker +############################################################################### +CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ + --gpu-memory-utilization 0.7 \ + --port "$ENCODE_PORT" \ + --enforce-eager \ + --enable-request-id-headers \ + --no-enable-prefix-caching \ + --max-num-batched-tokens 4096 \ + --max-num-seqs 128 \ + --ec-transfer-config "{ + \"ec_connector\": \"ECMooncakeStorageConnector\", + \"ec_role\": \"ec_producer\", + \"ec_connector_extra_config\": { + \"local_hostname\": \"$MOONCAKE_MASTER_IP\", + \"metadata_server\": \"http://localhost:$MOONCAKE_METADATA_PORT/metadata\", + \"global_segment_size\": $MOONCAKE_GLOBAL_SEGMENT_SIZE, + \"local_buffer_size\": $MOONCAKE_LOCAL_BUFFER_SIZE, + \"protocol\": \"tcp\", + \"device_name\": \"\", + \"master_server_address\": \"localhost:$MOONCAKE_MASTER_PORT\", + \"replica_num\": $MOONCAKE_REPLICA_NUM, + \"fast_transfer\": $MOONCAKE_FAST_TRANSFER, + \"fast_transfer_buffer_size\": $MOONCAKE_FAST_TRANSFER_BUFFER_SIZE + } + }" \ + >"${ENC_LOG}" 2>&1 & + +PIDS+=($!) + +############################################################################### +# Prefill worker +############################################################################### +CUDA_VISIBLE_DEVICES="$GPU_P" \ +UCX_NET_DEVICES=all \ +VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \ +vllm serve "$MODEL" \ + --gpu-memory-utilization 0.8 \ + --port "$PREFILL_PORT" \ + --enforce-eager \ + --enable-request-id-headers \ + --max-num-seqs 128 \ + --ec-transfer-config "{ + \"ec_connector\": \"ECMooncakeStorageConnector\", + \"ec_role\": \"ec_consumer\", + \"ec_connector_extra_config\": { + \"local_hostname\": \"$MOONCAKE_STORE_INSTANCE_IP\", + \"metadata_server\": \"http://localhost:$MOONCAKE_METADATA_PORT/metadata\", + \"global_segment_size\": 0, + \"local_buffer_size\": $MOONCAKE_LOCAL_BUFFER_SIZE, + \"protocol\": \"tcp\", + \"device_name\": \"\", + \"master_server_address\": \"localhost:$MOONCAKE_MASTER_PORT\", + \"replica_num\": $MOONCAKE_REPLICA_NUM, + \"fast_transfer\": $MOONCAKE_FAST_TRANSFER, + \"fast_transfer_buffer_size\": $MOONCAKE_FAST_TRANSFER_BUFFER_SIZE + } + }" \ + --kv-transfer-config '{ + "kv_connector": "NixlConnector", + "kv_role": "kv_producer" + }' \ + >"${P_LOG}" 2>&1 & + +PIDS+=($!) 
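+
+# NOTE: each vLLM instance on this host needs its own NIXL side-channel port:
+# the prefill worker above uses VLLM_NIXL_SIDE_CHANNEL_PORT=5559 and the
+# decode worker below uses 6000, so the two NixlConnector endpoints do not
+# collide.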
+
+###############################################################################
+# Decode worker
+###############################################################################
+CUDA_VISIBLE_DEVICES="$GPU_D" \
+UCX_NET_DEVICES=all \
+VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \
+vllm serve "$MODEL" \
+    --gpu-memory-utilization 0.7 \
+    --port "$DECODE_PORT" \
+    --enforce-eager \
+    --enable-request-id-headers \
+    --max-num-seqs 128 \
+    --kv-transfer-config '{
+        "kv_connector": "NixlConnector",
+        "kv_role": "kv_consumer"
+    }' \
+    >"${D_LOG}" 2>&1 &
+
+PIDS+=($!)
+
+# Wait for workers
+wait_for_server $ENCODE_PORT
+wait_for_server $PREFILL_PORT
+wait_for_server $DECODE_PORT
+
+###############################################################################
+# Proxy
+###############################################################################
+python ../disagg_epd_proxy.py \
+    --host "0.0.0.0" \
+    --port "$PROXY_PORT" \
+    --encode-servers-urls "http://localhost:$ENCODE_PORT" \
+    --prefill-servers-urls "http://localhost:$PREFILL_PORT" \
+    --decode-servers-urls "http://localhost:$DECODE_PORT" \
+    >"${PROXY_LOG}" 2>&1 &
+
+PIDS+=($!)
+
+wait_for_server $PROXY_PORT
+echo "All services are up!"
+
+###############################################################################
+# Benchmark
+###############################################################################
+echo "Running benchmark (stream)..."
+vllm bench serve \
+    --model "$MODEL" \
+    --dataset-name random-mm \
+    --num-prompts "$NUM_PROMPTS" \
+    --random-input-len 150 \
+    --random-output-len 100 \
+    --random-range-ratio 0.0 \
+    --random-mm-base-items-per-request 1 \
+    --random-mm-num-mm-items-range-ratio 0 \
+    --random-mm-limit-mm-per-prompt '{"image":2,"video":0}' \
+    --random-mm-bucket-config '{(700, 728, 1): 1.0}' \
+    --ignore-eos \
+    --backend openai-chat \
+    --endpoint /v1/chat/completions \
+    --port "$PROXY_PORT"
+
+# cleanup
+echo "cleanup..."
+cleanup diff --git a/examples/online_serving/disaggregated_encoder/mooncake_connector/disagg_1e1pd_example.sh b/examples/online_serving/disaggregated_encoder/mooncake_connector/disagg_1e1pd_example.sh new file mode 100755 index 000000000000..c2397b194125 --- /dev/null +++ b/examples/online_serving/disaggregated_encoder/mooncake_connector/disagg_1e1pd_example.sh @@ -0,0 +1,201 @@ +#!/bin/bash +set -euo pipefail + +declare -a PIDS=() + +############################################################################### +# Configuration -- override via env before running +############################################################################### +MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}" +LOG_PATH="${LOG_PATH:-./logs}" +mkdir -p $LOG_PATH + +ENCODE_PORT="${ENCODE_PORT:-19534}" +PREFILL_DECODE_PORT="${PREFILL_DECODE_PORT:-19535}" +PROXY_PORT="${PROXY_PORT:-10001}" + +GPU_E="${GPU_E:-0}" +GPU_PD="${GPU_PD:-1}" + +TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout +NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark + +MOONCAKE_MASTER_PORT=50051 +MOONCAKE_METADATA_PORT=8080 +MOONCAKE_MASTER_IP="localhost" # producer +MOONCAKE_STORE_INSTANCE_IP="localhost" # consumer +MOONCAKE_GLOBAL_SEGMENT_SIZE=$((30 * 1073741824)) +MOONCAKE_LOCAL_BUFFER_SIZE=$((1 * 1073741824)) +MOONCAKE_REPLICA_NUM=1 +MOONCAKE_TRANSFER_BUFFER_SIZE=$((30 * 1073741824)) + +############################################################################### +# Helpers +############################################################################### +START_TIME=$(date +"%Y%m%d_%H%M%S") +ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log +PD_LOG=$LOG_PATH/pd_${START_TIME}.log +PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log +MOONCAKE_MASTER_LOG="$LOG_PATH/mooncake_master_$START_TIME.log" + +wait_for_server() { + local port=$1 + timeout "$TIMEOUT_SECONDS" bash -c " + until curl -s localhost:$port/v1/chat/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Cleanup function +cleanup() { + echo "Stopping everything…" + trap - INT TERM USR1 # prevent re-entrancy + + # Kill all tracked PIDs + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Killing process $pid" + kill "$pid" 2>/dev/null + fi + done + + # Wait a moment for graceful shutdown + wait "${PIDS[@]}" 2>/dev/null || true + + # Force kill any remaining processes + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Force killing process $pid" + kill -9 "$pid" 2>/dev/null + fi + done + + echo "Force killing mooncake processes" + pkill -f "mooncake_master" + + echo "All processes stopped." 
+ exit 0 +} + +trap cleanup INT +trap cleanup USR1 +trap cleanup TERM + +############################################################################### +# Initialize Mooncake +# Read more about Mooncake config at +# https://kvcache-ai.github.io/Mooncake/deployment/mooncake-store-deployment-guide.html +############################################################################### +mooncake_master \ + --rpc_port $MOONCAKE_MASTER_PORT \ + --enable_http_metadata_server=true \ + --http_metadata_server_host=0.0.0.0 \ + --http_metadata_server_port=$MOONCAKE_METADATA_PORT \ + --rpc_thread_num 8 \ + --default_kv_lease_ttl 5000 \ + --eviction_ratio 0.05 \ + --eviction_high_watermark_ratio 0.9 \ + >"$MOONCAKE_MASTER_LOG" 2>&1 & + +export MC_MS_AUTO_DISC=0 + +############################################################################### +# Encoder worker +############################################################################### +CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ + --gpu-memory-utilization 0.8 \ + --port "$ENCODE_PORT" \ + --enforce-eager \ + --enable-request-id-headers \ + --no-enable-prefix-caching \ + --max-num-batched-tokens 65536 \ + --max-num-seqs 128 \ + --ec-transfer-config "{ + \"ec_connector\": \"ECMooncakeStorageConnector\", + \"ec_role\": \"ec_producer\", + \"ec_connector_extra_config\": { + \"local_hostname\": \"$MOONCAKE_MASTER_IP\", + \"metadata_server\": \"http://localhost:$MOONCAKE_METADATA_PORT/metadata\", + \"global_segment_size\": $MOONCAKE_GLOBAL_SEGMENT_SIZE, + \"local_buffer_size\": $MOONCAKE_LOCAL_BUFFER_SIZE, + \"protocol\": \"tcp\", + \"device_name\": \"\", + \"master_server_address\": \"localhost:$MOONCAKE_MASTER_PORT\", + \"replica_num\": $MOONCAKE_REPLICA_NUM, + \"transfer_buffer_size\": $MOONCAKE_TRANSFER_BUFFER_SIZE + } + }" \ + >"${ENC_LOG}" 2>&1 & + +PIDS+=($!) + +############################################################################### +# Prefill+Decode worker +############################################################################### +CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \ + --gpu-memory-utilization 0.8 \ + --port "$PREFILL_DECODE_PORT" \ + --enforce-eager \ + --enable-request-id-headers \ + --max-num-seqs 128 \ + --ec-transfer-config "{ + \"ec_connector\": \"ECMooncakeStorageConnector\", + \"ec_role\": \"ec_consumer\", + \"ec_connector_extra_config\": { + \"local_hostname\": \"$MOONCAKE_STORE_INSTANCE_IP\", + \"metadata_server\": \"http://localhost:$MOONCAKE_METADATA_PORT/metadata\", + \"global_segment_size\": 0, + \"local_buffer_size\": $MOONCAKE_LOCAL_BUFFER_SIZE, + \"protocol\": \"tcp\", + \"device_name\": \"\", + \"master_server_address\": \"localhost:$MOONCAKE_MASTER_PORT\", + \"replica_num\": $MOONCAKE_REPLICA_NUM, + \"transfer_buffer_size\": $MOONCAKE_TRANSFER_BUFFER_SIZE + } + }" \ + >"${PD_LOG}" 2>&1 & + +PIDS+=($!) + +# Wait for workers +wait_for_server $ENCODE_PORT +wait_for_server $PREFILL_DECODE_PORT + +############################################################################### +# Proxy +############################################################################### +python ../disagg_epd_proxy.py \ + --host "0.0.0.0" \ + --port "$PROXY_PORT" \ + --encode-servers-urls "http://localhost:$ENCODE_PORT" \ + --prefill-servers-urls "disable" \ + --decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \ + >"${PROXY_LOG}" 2>&1 & + +PIDS+=($!) + +wait_for_server $PROXY_PORT +echo "All services are up!" 
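+
+# Optional sanity check before benchmarking: send one multimodal chat request
+# through the proxy. The image URL below is a placeholder -- substitute any
+# reachable image:
+#
+# curl -s "http://localhost:$PROXY_PORT/v1/chat/completions" \
+#     -H "Content-Type: application/json" \
+#     -d '{"model": "'"$MODEL"'", "messages": [{"role": "user", "content": [
+#           {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
+#           {"type": "text", "text": "Describe this image."}]}]}'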
+
+###############################################################################
+# Benchmark
+###############################################################################
+vllm bench serve \
+    --model "$MODEL" \
+    --dataset-name random-mm \
+    --num-prompts "$NUM_PROMPTS" \
+    --random-input-len 150 \
+    --random-output-len 100 \
+    --random-range-ratio 0.0 \
+    --random-mm-base-items-per-request 1 \
+    --random-mm-num-mm-items-range-ratio 0 \
+    --random-mm-limit-mm-per-prompt '{"image":2,"video":0}' \
+    --random-mm-bucket-config '{(700, 728, 1): 1.0}' \
+    --ignore-eos \
+    --backend openai-chat \
+    --endpoint /v1/chat/completions \
+    --port "$PROXY_PORT"
+
+# cleanup
+echo "cleanup..."
+cleanup
\ No newline at end of file
diff --git a/tests/v1/ec_connector/unit/test_mooncake_store.py b/tests/v1/ec_connector/unit/test_mooncake_store.py
new file mode 100644
index 000000000000..487f9a047d92
--- /dev/null
+++ b/tests/v1/ec_connector/unit/test_mooncake_store.py
@@ -0,0 +1,354 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import ctypes
+import json
+from dataclasses import dataclass
+from unittest import mock
+
+import pytest
+import torch
+
+from vllm.config import VllmConfig
+from vllm.distributed.ec_transfer.ec_lookup_buffer.mooncake_store import ECMooncakeStore
+from vllm.distributed.ec_transfer.utils.tensor_memory_pool import (
+    InsufficientMemoryError,
+    TensorMemoryPool,
+)
+from vllm.platforms import current_platform
+
+DEFAULT_BUFFER_SIZE = 1024
+
+
+# Fake implementation of MooncakeDistributedStore for testing
+class FakeMooncakeDistributedStore:
+    """A fake implementation of MooncakeDistributedStore used as a test double.
+
+    This mock class is used in unit tests to simulate the behavior of the
+    MooncakeDistributedStore without requiring the actual Mooncake library.
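+
+    Data is kept in a plain dict, and "registered buffers" are tracked as
+    (address, size) tuples so the zero-copy read/write paths can be
+    bounds-checked.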
+ """ + + def __init__(self): + self.data = {} # key -> bytes or tensors + self.registered_buffers: set[tuple[int, int]] = set() + self.remove_calls = [] # Track remove_by_regex calls + + def setup( + self, + local_hostname, + metadata_server, + global_segment_size, + local_buffer_size, + protocol, + device_name, + master_server_address, + ): + pass # No-op for fake + + def close(self): + pass # No-op + + def batch_is_exist(self, keys): + return [k in self.data for k in keys] + + def get_batch(self, keys): + return [self.data.get(k) for k in keys] + + # List of bytes read for each operation (positive = success, negative = error) + def batch_get_into(self, keys, addrs, sizes): + results = [] + for key, addr, size in zip(keys, addrs, sizes): + if key in self.data and any( + addr >= baddr and addr + size <= baddr + bsize + for baddr, bsize in self.registered_buffers + ): + # Simulate copy: put data into buffer + buffer = (ctypes.c_char * len(self.data[key])).from_buffer( + bytearray(self.data[key]) + ) + ctypes.memmove(addr, ctypes.addressof(buffer), size) + results.append(size) + else: + results.append(-1) + return results + + def put_batch(self, keys, values, replica_config): + for key, value in zip(keys, values): + self.data[key] = value + + def batch_put_from(self, keys, addrs, sizes, replica_config): + for key, addr, size in zip(keys, addrs, sizes): + if any( + addr >= baddr and addr + size <= baddr + bsize + for baddr, bsize in self.registered_buffers + ): + data: bytes = ctypes.string_at(addr, size) + self.data[key] = data[:size] + + def register_buffer(self, addr, size): + self.registered_buffers.add((addr, size)) + + def unregister_buffer(self, addr, size): + self.registered_buffers.remove((addr, size)) + + def remove_by_regex(self, pattern): + import regex as re + + regex = re.compile(pattern) + count = 0 + for key in list(self.data.keys()): + if regex.match(key): + del self.data[key] + count += 1 + self.remove_calls.append(pattern) + return count + + +# Fake ReplicateConfig +@dataclass +class FakeReplicateConfig: + replica_num: int = 1 + + +@pytest.fixture +def mock_inner_mooncake_store(monkeypatch): + fake_store = FakeMooncakeDistributedStore() + monkeypatch.setattr("mooncake.store.MooncakeDistributedStore", lambda: fake_store) + monkeypatch.setattr("mooncake.store.ReplicateConfig", FakeReplicateConfig) + return fake_store + + +@pytest.fixture +def vllm_config(): + config = mock.Mock(spec=VllmConfig) + config.ec_transfer_config = mock.Mock() + config.device_config.device = current_platform.device_type + config.ec_transfer_config.ec_connector_extra_config = { + "local_hostname": "test_host", + "metadata_server": "test_meta", + "global_segment_size": DEFAULT_BUFFER_SIZE, + "local_buffer_size": DEFAULT_BUFFER_SIZE, + "protocol": "tcp", + "device_name": "test_device", + "master_server_address": "test_master", + "storage_root_dir": "", + "transfer_timeout": 5, + "replica_num": 2, + "fast_transfer": False, + "fast_transfer_buffer_size": DEFAULT_BUFFER_SIZE, + } + return config + + +@pytest.fixture +def ec_mooncake_store(vllm_config, mock_inner_mooncake_store): + store = ECMooncakeStore(vllm_config) + yield store + try: + store.close() + except RuntimeError as e: + if "Event loop is closed" in str(e): + # exception for test_close() + return + else: + raise + + +def test_init(vllm_config, mock_inner_mooncake_store): + # Mock methods + mock_inner_mooncake_store.setup = mock.MagicMock(name="setup") + + store = ECMooncakeStore(vllm_config) + assert store.config.local_hostname == "test_host" 
+ assert store.config.replica_num == 2 + assert not store.config.fast_transfer + mock_inner_mooncake_store.setup.assert_called_once() + store.close() + + +def test_init_with_fast_transfer(monkeypatch, vllm_config, mock_inner_mooncake_store): + # Mock methods + mock_inner_mooncake_store.register_buffer = mock.MagicMock(name="register_buffer") + mock_inner_mooncake_store.unregister_buffer = mock.MagicMock( + name="unregister_buffer" + ) + tensorpool = mock.Mock(spec=TensorMemoryPool) + monkeypatch.setattr( + "vllm.distributed.ec_transfer.utils.tensor_memory_pool.TensorMemoryPool", + lambda max_block_size: tensorpool, + ) + + # Modify config to enable fast_transfer + vllm_config.ec_transfer_config.ec_connector_extra_config["fast_transfer"] = True + + store = ECMooncakeStore(vllm_config) + assert store.config.fast_transfer + mock_inner_mooncake_store.register_buffer.assert_called_with( + mock.ANY, DEFAULT_BUFFER_SIZE + ) + store.close() + mock_inner_mooncake_store.unregister_buffer.assert_called_with( + mock.ANY, DEFAULT_BUFFER_SIZE + ) + # Make sure it registers & unregisters the same buffer + assert ( + mock_inner_mooncake_store.register_buffer.call_args + == mock_inner_mooncake_store.unregister_buffer.call_args + ) + + +def test_batch_exists(ec_mooncake_store, mock_inner_mooncake_store): + mock_inner_mooncake_store.data = {"key1": b"data1", "key2": b"data2"} + exists = ec_mooncake_store.batch_exists(["key1", "key3", "key2"]) + assert exists == [True, False, True] + exists = ec_mooncake_store.batch_exists([]) + assert exists == [] + + +def test_batch_get_non_fast(ec_mooncake_store, mock_inner_mooncake_store, vllm_config): + # Prepare serialized data + tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) + meta = { + "shape": list(tensor.shape), + "original_dtype": str(tensor.dtype), + "serialized_dtype": "float32", + } + + meta_bytes = json.dumps(meta).encode("utf-8") + len_bytes = len(meta_bytes).to_bytes(4, "big") + data_bytes = tensor.cpu().numpy().tobytes() + serialized = len_bytes + meta_bytes + data_bytes + + mock_inner_mooncake_store.data = {"key1": serialized, "key2": None} + + results = ec_mooncake_store.batch_get( + ["key1", "key2"], device=vllm_config.device_config.device + ) + assert torch.equal(results[0].cpu(), tensor.cpu()) + assert results[1] is None + + +def test_batch_put_non_fast(ec_mooncake_store, mock_inner_mooncake_store, vllm_config): + tensors = [ + torch.randn( + (2, 2), dtype=torch.bfloat16, device=vllm_config.device_config.device + ), + torch.randn((1, 4), dtype=torch.float32), + torch.tensor( + [[1, 2]], dtype=torch.int32, device=vllm_config.device_config.device + ), + ] + keys = ["key1", "key2", "key3"] + + ec_mooncake_store.batch_put(keys, tensors) + ec_mooncake_store.wait_for_put() + + assert "key1" in mock_inner_mooncake_store.data + assert "key2" in mock_inner_mooncake_store.data + assert "key3" in mock_inner_mooncake_store.data + + # Verify deserialization + stored1 = mock_inner_mooncake_store.data["key1"] + len_meta = int.from_bytes(stored1[:4], "big") + meta = json.loads(stored1[4 : 4 + len_meta].decode("utf-8")) + assert meta["original_dtype"] == "torch.bfloat16" + + results = ec_mooncake_store.batch_get( + ["key1", "key2", "key3"], device=vllm_config.device_config.device + ) + + assert torch.equal(results[0].cpu(), tensors[0].cpu()) + assert torch.equal(results[1].cpu(), tensors[1].cpu()) + assert torch.equal(results[2].cpu(), tensors[2].cpu()) + + +def test_batch_get_zero_copy(monkeypatch, vllm_config, mock_inner_mooncake_store): + # Enable 
fast_transfer
+    vllm_config.ec_transfer_config.ec_connector_extra_config["fast_transfer"] = True
+
+    store = ECMooncakeStore(vllm_config)
+
+    # Prepare metadata
+    meta = {"shape": [2, 2], "dtype": "torch.float32"}
+    meta_bytes = json.dumps(meta).encode("utf-8")
+    value1 = torch.randn((2, 2))
+    value1_bytes = value1.numpy().tobytes()
+    mock_inner_mooncake_store.data = {
+        "key1_metadata": meta_bytes,
+        "key1": value1_bytes,
+    }
+
+    results = store.batch_get(["key1", "key2"], device=vllm_config.device_config.device)
+    assert torch.equal(value1.cpu(), results[0].cpu())
+    assert results[1] is None
+
+    store.close()
+
+
+def test_batch_put_zero_copy(monkeypatch, vllm_config, mock_inner_mooncake_store):
+    # Enable fast_transfer
+    vllm_config.ec_transfer_config.ec_connector_extra_config["fast_transfer"] = True
+
+    store = ECMooncakeStore(vllm_config)
+
+    tensors = [
+        torch.tensor(
+            [[1, 2]], dtype=torch.int32, device=vllm_config.device_config.device
+        ),
+        torch.tensor(
+            [[3.0, 4.0]], dtype=torch.float32, device=vllm_config.device_config.device
+        ),
+    ]
+    keys = ["key1", "key2"]
+
+    store.batch_put(keys, tensors)
+    store.wait_for_put()
+
+    assert (
+        mock_inner_mooncake_store.data.get("key1") == tensors[0].cpu().numpy().tobytes()
+    )
+    assert (
+        mock_inner_mooncake_store.data.get("key2") == tensors[1].cpu().numpy().tobytes()
+    )
+    assert store.metadata_key("key1") in mock_inner_mooncake_store.data
+    assert store.metadata_key("key2") in mock_inner_mooncake_store.data
+
+    store.close()
+
+
+def test_pool_eviction(monkeypatch, vllm_config, mock_inner_mooncake_store):
+    # Enable fast_transfer
+    vllm_config.ec_transfer_config.ec_connector_extra_config["fast_transfer"] = True
+
+    store = ECMooncakeStore(vllm_config)
+
+    evict_tensor = torch.randn(
+        (4, 4), dtype=torch.float32, device=vllm_config.device_config.device
+    )
+    store.batch_put(["evict_key"], [evict_tensor])
+    store.wait_for_put()
+
+    # Trigger allocation with eviction, 16 * 16 * 4 = 1024
+    new_tensor = torch.randn(
+        (16, 16), dtype=torch.float32, device=vllm_config.device_config.device
+    )
+    with (
+        mock.patch.object(store.tensor_pool, "_allocate") as mock_allocate,
+        mock.patch.object(store.tensor_pool, "free") as mock_free,
+    ):
+        mock_allocate.side_effect = [
+            InsufficientMemoryError("Not enough memory"),
+            mock.Mock(),
+        ]
+
+        store.batch_put(["new_key"], [new_tensor])
+        store.wait_for_put()
+
+        mock_free.assert_called_once()
+        assert mock_allocate.call_count == 2
+    store.close()
+
+
+def test_close(ec_mooncake_store, mock_inner_mooncake_store):
+    mock_inner_mooncake_store.close = mock.MagicMock(name="close")
+    ec_mooncake_store.close()
+    mock_inner_mooncake_store.close.assert_called_once()
diff --git a/vllm/distributed/ec_transfer/ec_connector/base.py b/vllm/distributed/ec_transfer/ec_connector/base.py
index 2b7b14d89b8a..84d292d6eb63 100644
--- a/vllm/distributed/ec_transfer/ec_connector/base.py
+++ b/vllm/distributed/ec_transfer/ec_connector/base.py
@@ -159,6 +159,14 @@ def save_caches(
         """
         pass
 
+    def wait_for_save(self):
+        """
+        Block until all save operations are done. This is called
+        to ensure that async operations are complete before
+        notifying the proxy that image processing is finished.
+ """ + return + def get_finished( self, finished_req_ids: set[str] ) -> tuple[set[str] | None, set[str] | None]: diff --git a/vllm/distributed/ec_transfer/ec_connector/factory.py b/vllm/distributed/ec_transfer/ec_connector/factory.py index 32f36ffbb14d..3043e48910cd 100644 --- a/vllm/distributed/ec_transfer/ec_connector/factory.py +++ b/vllm/distributed/ec_transfer/ec_connector/factory.py @@ -83,3 +83,9 @@ def get_connector_class( "vllm.distributed.ec_transfer.ec_connector.example_connector", "ECExampleConnector", ) + +ECConnectorFactory.register_connector( + "ECMooncakeStorageConnector", + "vllm.distributed.ec_transfer.ec_connector.mooncake_storage_connector", + "ECMooncakeStorageConnector", +) diff --git a/vllm/distributed/ec_transfer/ec_connector/mooncake_storage_connector.py b/vllm/distributed/ec_transfer/ec_connector/mooncake_storage_connector.py new file mode 100644 index 000000000000..46da9bd91ec3 --- /dev/null +++ b/vllm/distributed/ec_transfer/ec_connector/mooncake_storage_connector.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from vllm.config import VllmConfig +from vllm.distributed.ec_transfer.ec_connector.base import ( + ECConnectorBase, + ECConnectorMetadata, + ECConnectorRole, +) +from vllm.distributed.ec_transfer.ec_lookup_buffer.mooncake_store import ( + ECMooncakeStore, + MooncakeLoadMeta +) +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class MMMeta: + mm_hash: str + num_token: int + + @staticmethod + def make_meta(mm_hash, num_token) -> "MMMeta": + return MMMeta(mm_hash=mm_hash, num_token=num_token) + + +@dataclass +class ECMooncakeStorageConnectorMetadata(ECConnectorMetadata): + mm_datas: list[MMMeta] + + def __init__(self): + self.mm_datas = [] + + def add_mm_data(self, mm_data: MMMeta): + self.mm_datas.append(mm_data) + + +class ECMooncakeStorageConnector(ECConnectorBase): + def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + # mm_hash -> num_tokens + self._mm_datas_need_loads: dict[str, int] = {} + self.store = ECMooncakeStore(vllm_config) + + def start_load_caches(self, encoder_cache, **kwargs) -> None: + """ + Start loading the cache from the connector into vLLM's encoder cache. + + This method loads the encoder cache based on metadata provided by the scheduler. + It is called before `_gather_mm_embeddings` for the EC Connector. For EC, + the `encoder_cache` and `mm_hash` are stored in `kwargs`. + + Args: + encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal + data hashes (`mm_hash`) to encoder cache tensors. + kwargs (dict): Additional keyword arguments for the connector. 
+ """ + + # Get the metadata + metadata: ECConnectorMetadata = self._get_connector_metadata() + assert isinstance(metadata, ECMooncakeStorageConnectorMetadata) + assert encoder_cache is not None + if not metadata.mm_datas: + return + + load_metas = [ + MooncakeLoadMeta( + key=mm_data.mm_hash, + num_token=mm_data.num_token + ) + for mm_data in metadata.mm_datas + if mm_data.mm_hash not in encoder_cache + ] + device = self._vllm_config.device_config.device + tensors = self.store.batch_get(load_metas, device) + + for load_meta, ec_cache in zip(load_metas, tensors): + encoder_cache[load_meta.key] = ec_cache + if ec_cache is None: + logger.error("Load failed for %s", load_meta.key) + logger.debug("Load tensor for %s successfully", load_meta.key) + + def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None: + """ + Save the encoder cache to the connector. + + This method saves the encoder cache from the worker's local storage + to shared storage or another external connector. + + Args: + encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal + data hashes (`mm_hash`) to encoder cache tensors. + mm_hash (str): The hash of the multimodal data whose cache is being saved. + kwargs (dict): Additional keyword arguments for the connector. + """ + if not self.is_producer: + return + assert encoder_cache is not None + assert mm_hash is not None + self.store.batch_put([mm_hash], [encoder_cache[mm_hash]]) + + def wait_for_save(self): + self.store.wait_for_put() + + def has_caches( + self, + request: "Request", + ) -> list[bool]: + """ + Check if cache exist externally for each mm_data of request + + Args: + request (Request): the request object. + + Returns: + List of bool indicate that ith mm_data exist in cache or not + """ + mm_hashes = [feature.identifier for feature in request.mm_features] + return self.store.batch_exists(mm_hashes) + + def update_state_after_alloc( + self, + request: "Request", + index: int, + ) -> None: + """ + Update ECConnector state after encoder cache allocation. + """ + mm_hash = request.mm_features[index].identifier + num_encoder_token = request.get_num_encoder_tokens(index) + # Insert mm_hash only if this block has not been recorded yet. + self._mm_datas_need_loads[mm_hash] = num_encoder_token + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> ECConnectorMetadata: + """Build the connector metadata for this step. + + This function should NOT modify any fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + This only build for load mm_data only + Args: + scheduler_output (SchedulerOutput): the scheduler output object. 
+ """ + meta = ECMooncakeStorageConnectorMetadata() + for mm_hash, num_encoder_token in self._mm_datas_need_loads.items(): + meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token)) + self._mm_datas_need_loads.clear() + return meta diff --git a/vllm/distributed/ec_transfer/ec_lookup_buffer/__init__.py b/vllm/distributed/ec_transfer/ec_lookup_buffer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/distributed/ec_transfer/ec_lookup_buffer/mooncake_store.py b/vllm/distributed/ec_transfer/ec_lookup_buffer/mooncake_store.py new file mode 100644 index 000000000000..e52c8d4d6890 --- /dev/null +++ b/vllm/distributed/ec_transfer/ec_lookup_buffer/mooncake_store.py @@ -0,0 +1,349 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This file contains a new class `MooncakeStore` that allows developers to +think of EC cache transfer operations as putting new EC cache entries +into a remote ECStore-based lookup buffer and getting existing EC caches +from this remote lookup buffer. +""" + +import asyncio +import json +import math +import multiprocessing +import os +import threading +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import Any + +import numpy as np +import regex as re +import torch + +from vllm.config import VllmConfig +from vllm.distributed.ec_transfer.utils.tensor_memory_pool import ( + TensorMemoryPool, +) +from vllm.logger import init_logger + +METADATA_SUFFIX = "_metadata" +DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB +DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB +DEFAULT_TENSOR_POOL_SIZE = 1073741824 # 1.0 GiB + +logger = init_logger(__name__) + + +@dataclass +class MooncakeStoreConfig: + local_hostname: str + metadata_server: str + global_segment_size: int + local_buffer_size: int + protocol: str + device_name: str + master_server_address: str + storage_root_dir: str + transfer_timeout: int + replica_num: int + transfer_buffer_size: int + + @staticmethod + def from_config(config: dict[str, Any]) -> "MooncakeStoreConfig": + """Load the config from a JSON file.""" + return MooncakeStoreConfig( + local_hostname=config.get("local_hostname", "localhost"), + metadata_server=config.get("metadata_server", ""), + global_segment_size=config.get( + "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE + ), + local_buffer_size=config.get( + "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE + ), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", ""), + master_server_address=config.get("master_server_address", ""), + storage_root_dir=config.get("storage_root_dir", ""), + transfer_timeout=int(config.get("transfer_timeout", 1)), + replica_num=int(config.get("replica_num", 1)), + transfer_buffer_size=int( + config.get("transfer_buffer_size", DEFAULT_TENSOR_POOL_SIZE) + ), + ) + + +@dataclass +class MooncakeLoadMeta: + key: str + num_token: int + +class ECMooncakeStore: + """ + Currently, it only supports zero-copy get/put with + following data path gpu->cpu->cpu->gpu + """ + + def __init__(self, vllm_config: "VllmConfig"): + try: + from mooncake.store import MooncakeDistributedStore, ReplicateConfig + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run vLLM with MooncakeConnector." 
+ ) from e + + try: + if vllm_config.ec_transfer_config is None: + raise ValueError("ec_transfer_config must be set for ECConnectorBase") + + self.store = MooncakeDistributedStore() + self.config = MooncakeStoreConfig.from_config( + vllm_config.ec_transfer_config.ec_connector_extra_config + ) + logger.debug("Mooncake Configuration loaded successfully.") + + # Check if storage_root_dir exists and set environment variable + if ( + self.config.storage_root_dir is not None + and self.config.storage_root_dir != "" + ): + os.environ["MOONCAKE_STORAGE_ROOT_DIR"] = self.config.storage_root_dir + logger.info( + "Set MOONCAKE_STORAGE_ROOT_DIR to: %s", self.config.storage_root_dir + ) + + logger.info("Setting up Mooncake store with parameters:") + logger.info(" local_hostname: %s", self.config.local_hostname) + logger.info(" metadata_server: %s", self.config.metadata_server) + logger.info(" global_segment_size: %s", self.config.global_segment_size) + logger.info(" local_buffer_size: %s", self.config.local_buffer_size) + logger.info(" protocol: %s", self.config.protocol) + logger.info(" device_name: %s", self.config.device_name) + logger.info( + " master_server_address: %s", self.config.master_server_address + ) + logger.info(" transfer_timeout: %s", self.config.transfer_timeout) + logger.info(" replica_num: %s", self.config.replica_num) + logger.info( + " transfer_buffer_size: %s", self.config.transfer_buffer_size + ) + + self.store.setup( + self.config.local_hostname, + self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, + self.config.device_name, + self.config.master_server_address, + ) + + except ValueError as e: + logger.error("Configuration loading failed: %s", e) + raise + except Exception as exc: + logger.error("An error occurred while loading the configuration: %s", exc) + raise + + # Initialize ReplicateConfig + self.replica_config = ReplicateConfig() + self.replica_config.replica_num = self.config.replica_num + + logger.info("MooncakeConnector initialized successfully.") + + self.tensor_pool = TensorMemoryPool( + max_block_size=self.config.transfer_buffer_size + ) + self.pool_lock = threading.Lock() + self.store.register_buffer( + self.tensor_pool.base_address, self.config.transfer_buffer_size + ) + + # Put async init + # queue of unfinished put requests stored by keys + self.put_queue: set[str] = set() + self.put_queue_cv = asyncio.Condition() + self.put_loop = asyncio.new_event_loop() + self.put_thread = threading.Thread( + target=self.put_loop.run_forever, daemon=True + ) + self.put_thread.start() + + max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8)) + self.io_executor = ThreadPoolExecutor(max_workers=max_workers) + + # model config + self.embed_size = vllm_config.model_config.get_inputs_embeds_size() + self.dtype = vllm_config.model_config.dtype \ + if isinstance(vllm_config.model_config.dtype, torch.dtype) \ + else getattr(torch, vllm_config.model_config.dtype) + + def close(self): + self.wait_for_put() + + if self.put_loop.is_running(): + self.put_loop.call_soon_threadsafe(self.put_loop.stop) + self.put_thread.join() + + self.put_loop.close() + + self.store.unregister_buffer( + self.tensor_pool.base_address, self.config.transfer_buffer_size + ) + self.tensor_pool.cleanup() + + self.store.close() + logger.info("Closed the mooncake store connection") + + def batch_exists(self, keys: list[str]) -> list[bool]: + if not keys: + return [] + return self.store.batch_is_exist(keys) + + def batch_remove(self, 
keys: list[str]) -> int: + if not keys: + return 0 + pattern = re.compile(r"\b(" + "|".join(re.escape(k) for k in keys) + r")\b") + return self.store.remove_by_regex(pattern) + + def metadata_key(self, key: str) -> str: + # TODO: no guarantee that there is no (k,v) with this key + return key + METADATA_SUFFIX + + def get(self, key: str) -> torch.Tensor | None: + logger.error("Single get operation is not supported. Use batch_get instead.") + raise NotImplementedError( + "Single get is not supported. Use batch_get([key]) instead." + ) + + def batch_get(self, metas: list[MooncakeLoadMeta], device) -> list[torch.Tensor | None]: + if not metas: + return [] + + buffer_shapes = [] + buffer_addrs = None + sizes = [] + for meta in metas: + buffer_shape = (meta.num_token, self.embed_size) + element_size = torch.tensor([], dtype=self.dtype).element_size() + num_elem = math.prod(buffer_shape) + buffer_size = num_elem * element_size + sizes.append(buffer_size) + buffer_shapes.append(buffer_shape) + + with self.pool_lock: + buffer_addrs = [ + self.tensor_pool.allocate(buffer_size) for buffer_size in sizes + ] + + # Fill None first and + # replace valid keys with corresponding buffers + results = [None] * len(metas) + try: + keys = [meta.key for meta in metas] + read_bytes = self.store.batch_get_into(keys, buffer_addrs, sizes) + except Exception as e: + with self.pool_lock: + self.tensor_pool.batch_free(buffer_addrs) + logger.error("batch_get_into failed: %s", str(e)) + return results + + for i in range(len(metas)): + if read_bytes[i] > 0: + results[i] = self.tensor_pool.load_tensor( + buffer_addrs[i], self.dtype, buffer_shapes[i], device + ) + else: + logger.debug("fail to load for key %s", metas[i].key) + + with self.pool_lock: + self.tensor_pool.batch_free(buffer_addrs) + + return results + + def put(self, key: str, tensor: torch.Tensor) -> None: + logger.error("Single put operation is not supported. Use batch_put instead.") + raise NotImplementedError( + "Single put is not supported. Use batch_put([key], [tensor]) instead." 
+    )
+
+    def wait_for_put(self):
+        future = asyncio.run_coroutine_threadsafe(
+            self._wait_for_put_async(), self.put_loop
+        )
+        future.result()  # wait until complete
+
+    async def _wait_for_put_async(self):
+        async with self.put_queue_cv:
+            while self.put_queue:
+                await self.put_queue_cv.wait()
+
+    def batch_put(self, keys: list[str], tensors: list[torch.Tensor]) -> None:
+        self.put_loop.call_soon_threadsafe(
+            lambda: self.put_loop.create_task(self._batch_put_async(keys, tensors))
+        )
+
+    async def _batch_put_async(
+        self, keys: list[str], tensors: list[torch.Tensor]
+    ) -> None:
+        async with self.put_queue_cv:
+            self.put_queue.update(keys)
+
+        try:
+            await self._zero_copy_batch_put(keys, tensors)
+        finally:
+            async with self.put_queue_cv:
+                self.put_queue.difference_update(keys)
+                if not self.put_queue:
+                    self.put_queue_cv.notify_all()
+
+    async def _zero_copy_batch_put(
+        self, keys: list[str], tensors: list[torch.Tensor]
+    ) -> None:
+        if not keys:
+            return
+
+        # Copy each tensor into a pinned pool buffer
+        buffer_addrs = []
+        buffer_sizes = []
+        with self.pool_lock:
+            for tensor in tensors:
+                buffer_addr = self.tensor_pool.store_tensor(tensor)
+                buffer_size = tensor.numel() * tensor.element_size()
+                buffer_addrs.append(buffer_addr)
+                buffer_sizes.append(buffer_size)
+
+        try:
+            await asyncio.wait_for(
+                asyncio.get_running_loop().run_in_executor(
+                    self.io_executor,
+                    self.store.batch_put_from,
+                    keys,
+                    buffer_addrs,
+                    buffer_sizes,
+                    self.replica_config,
+                ),
+                timeout=self.config.transfer_timeout,
+            )
+
+            # On success, do not free buffer_addrs here; the tensor pool's
+            # FIFO eviction reclaims them automatically.
+            buffer_addrs = []
+        except asyncio.TimeoutError:
+            logger.error(
+                "Timeout while putting keys %s (timeout=%s seconds)",
+                ",".join(keys),
+                self.config.transfer_timeout,
+            )
+        except Exception as e:
+            logger.error(
+                "Failed to put keys %s using batch_put_from with error %s",
+                ",".join(keys),
+                str(e),
+            )
+        finally:
+            if buffer_addrs:
+                with self.pool_lock:
+                    self.tensor_pool.batch_free(buffer_addrs)
diff --git a/vllm/distributed/ec_transfer/utils/__init__.py b/vllm/distributed/ec_transfer/utils/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/distributed/ec_transfer/utils/tensor_memory_pool.py b/vllm/distributed/ec_transfer/utils/tensor_memory_pool.py
new file mode 100644
index 000000000000..dda6e8d39f47
--- /dev/null
+++ b/vllm/distributed/ec_transfer/utils/tensor_memory_pool.py
@@ -0,0 +1,291 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import atexit
+import ctypes
+import math
+from collections import OrderedDict
+from dataclasses import dataclass
+
+import torch
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class MemoryBlock:
+    size: int
+    addr: int
+
+
+class InsufficientMemoryError(ValueError):
+    """Raised when there is insufficient memory in the tensor pool."""
+
+    pass
+
+
+class TensorMemoryPool:
+    """
+    A memory pool for managing pinned host memory allocations for tensors.
+
+    This class implements a buddy allocation system to efficiently manage pinned
+    host memory for tensor storage. It supports allocation, deallocation, and
+    tensor storage/retrieval operations.
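+
+    The pool itself is not thread-safe; callers (e.g. ECMooncakeStore)
+    are expected to serialize access with an external lock.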
+ + Key Features: + - Automatically evict when pool is full using FIFO policy + - Uses power-of-two block sizes for efficient buddy allocation + - Supports splitting and merging of memory blocks + - Provides methods to store CUDA tensors in pinned host memory + - Allows loading tensors from pinned memory back to device + - Automatically cleans up memory on destruction + + Attributes: + max_block_size (int): Maximum block size (rounded to nearest power of two) + min_block_size (int): Minimum block size (rounded to nearest power of two) + free_lists (dict): Dictionary of free memory blocks by size + allocated_blocks (dict): Dictionary of currently allocated blocks + base_tensor (torch.Tensor): Base pinned memory tensor + base_address (int): Base memory address of the pinned memory region + + Example: + >>> pool = TensorMemoryPool(max_block_size=1024 * 1024) + >>> tensor = torch.randn(100, device="cuda") + >>> addr = pool.store_tensor(tensor) + >>> loaded_tensor = pool.load_tensor(addr, tensor.dtype, tensor.shape, "cuda") + >>> pool.free(addr) + + Raises: + ValueError: If block sizes are invalid or max_block_size is less + than min_block_size + """ + + def __init__(self, max_block_size: int, min_block_size: int = 512): + if max_block_size <= 0 or min_block_size <= 0: + raise ValueError("Block sizes must be positive") + if max_block_size < min_block_size: + raise ValueError("Max block size must be greater than min block size") + + self.max_block_size = self._round_to_power_of_two(max_block_size) + self.min_block_size = self._round_to_power_of_two(min_block_size) + + self.free_lists: dict[int, dict[int, MemoryBlock]] = {} + self.allocated_blocks: OrderedDict[int, MemoryBlock] = OrderedDict() + + self._initialize_free_lists() + self._allocate_pinned_memory() + + atexit.register(self.cleanup) + + def _round_to_power_of_two(self, size: int) -> int: + return 1 << (size - 1).bit_length() + + def _initialize_free_lists(self): + size = self.max_block_size + while size >= self.min_block_size: + self.free_lists[size] = {} + size //= 2 + + def _allocate_pinned_memory(self): + self.base_tensor = torch.empty( + self.max_block_size // 4, dtype=torch.float32, pin_memory=True + ) + self.base_address = self.base_tensor.data_ptr() + initial_block = MemoryBlock(size=self.max_block_size, addr=self.base_address) + self.free_lists[self.max_block_size][initial_block.addr] = initial_block + + logger.debug( + "TensorMemoryPool, base_address:%d, max_block_size:%d", + self.base_address, + self.max_block_size, + ) + + def _allocate(self, required_size: int) -> int: + current_size = required_size + while current_size <= self.max_block_size: + if self.free_lists[current_size]: + _, block = self.free_lists[current_size].popitem() + self._split_block(block, required_size) + self.allocated_blocks[block.addr] = block + return block.addr + current_size *= 2 + raise InsufficientMemoryError() + + def allocate(self, size: int) -> int: + """Allocates a memory block of at least the requested size. 
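+
+        The requested size is rounded up to the next power of two (never
+        below ``min_block_size``). If no free block is large enough,
+        previously allocated blocks are evicted in FIFO order until the
+        request can be satisfied.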
+ + Args: + size (int): Minimum size of memory to allocate + + Returns: + int: Address of the allocated memory block + + Raises: + ValueError: If size is invalid or insufficient memory is available + """ + if size <= 0: + raise ValueError("Allocation size must be positive") + + required_size = self._round_to_power_of_two(max(size, self.min_block_size)) + if required_size > self.max_block_size: + raise ValueError("Requested size exceeds maximum block size") + + while True: + try: + return self._allocate(required_size) + except InsufficientMemoryError: + if self.allocated_blocks: + self.free() + else: + raise InsufficientMemoryError( + f"Can not allocate required size {required_size}" + ) from None + + def _split_block(self, block: MemoryBlock, required_size: int): + while block.size > required_size and block.size // 2 >= self.min_block_size: + buddy_size = block.size // 2 + buddy_addr = block.addr + buddy_size + + buddy = MemoryBlock(size=buddy_size, addr=buddy_addr) + block.size = buddy_size + + self.free_lists[buddy_size][buddy.addr] = buddy + + def free(self, addr: int | None = None): + """Frees an allocated memory block. + + Args: + addr (int | None): Address of the block to free. + When it is None, the first block is evicted. + + Raises: + ValueError: If address is invalid or not allocated + """ + if addr is None: + if self.allocated_blocks: + # Retrieved the earliest inserted key + addr = next(iter(self.allocated_blocks)) + else: + raise ValueError("No available block to free") + + if addr not in self.allocated_blocks: + raise ValueError("Invalid address to free") + + block = self.allocated_blocks.pop(addr) + self._merge_buddies(block) + + def batch_free(self, addrs: list[int]): + for addr in addrs: + self.free(addr) + + def _merge_buddies(self, block: MemoryBlock): + MAX_MERGE_DEPTH = 30 + depth = 0 + + while depth < MAX_MERGE_DEPTH: + buddy_offset = ( + block.size + if (block.addr - self.base_address) % (2 * block.size) == 0 + else -block.size + ) + buddy_addr = block.addr + buddy_offset + buddy = self.free_lists[block.size].get(buddy_addr) + if buddy: + del self.free_lists[buddy.size][buddy.addr] + merged_addr = min(block.addr, buddy.addr) + merged_size = block.size * 2 + block = MemoryBlock(size=merged_size, addr=merged_addr) + depth += 1 + else: + break + self.free_lists[block.size][block.addr] = block + + def store_tensor(self, tensor: torch.Tensor) -> int: + """Stores a CUDA tensor in pinned host memory. + + Args: + tensor (torch.Tensor): CUDA tensor to store + + Returns: + int: Address where the tensor is stored + + Raises: + ValueError: If tensor is not on CUDA or allocation fails + """ + if tensor.get_device() == -1: + raise ValueError("Only CUDA tensors can be stored") + + size = tensor.element_size() * tensor.numel() + addr = self.allocate(size) + block = self.allocated_blocks[addr] + + if block.size < size: + self.free(addr) + raise InsufficientMemoryError( + f"Allocated block size {block.size} is smaller than " + f"required size {size}" + ) + + try: + buffer = (ctypes.c_byte * block.size).from_address(block.addr) + cpu_tensor = torch.frombuffer( + buffer, dtype=tensor.dtype, count=tensor.numel() + ).reshape(tensor.shape) + except ValueError as err: + self.free(addr) + raise ValueError(f"Failed to create tensor view: {err}") from err + + cpu_tensor.copy_(tensor) + + return addr + + def load_tensor( + self, addr: int, dtype: torch.dtype, shape: tuple[int, ...], device + ) -> torch.Tensor: + """Loads a tensor from pinned host memory to the specified device. 
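+
+        The pinned block is viewed in place via ``torch.frombuffer`` and
+        copied into a freshly allocated tensor on ``device``; the block
+        remains allocated and must still be freed by the caller.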
+ + Args: + addr (int): Address where tensor is stored + dtype (torch.dtype): Data type of the tensor + shape (tuple[int, ...]): Shape of the tensor + device: Target device for the loaded tensor + + Returns: + torch.Tensor: The loaded tensor on the specified device + + Raises: + ValueError: If address is invalid or sizes don't match + """ + if addr not in self.allocated_blocks: + raise ValueError("Invalid address to load") + + block = self.allocated_blocks[addr] + num_elements = math.prod(shape) + dtype_size = torch.tensor([], dtype=dtype).element_size() + required_size = num_elements * dtype_size + + if required_size > block.size: + raise ValueError("Requested tensor size exceeds block size") + + buffer = (ctypes.c_byte * block.size).from_address(block.addr) + cpu_tensor = torch.frombuffer(buffer, dtype=dtype, count=num_elements).reshape( + shape + ) + + cuda_tensor = torch.empty(shape, dtype=dtype, device=device) + + cuda_tensor.copy_(cpu_tensor) + + return cuda_tensor + + def cleanup(self): + """Cleans up all memory resources and resets the pool state.""" + self.free_lists.clear() + self.allocated_blocks.clear() + if hasattr(self, "base_tensor"): + del self.base_tensor + + def __del__(self): + self.cleanup() diff --git a/vllm/v1/worker/ec_connector_model_runner_mixin.py b/vllm/v1/worker/ec_connector_model_runner_mixin.py index 08a41532ea8e..457798b92698 100644 --- a/vllm/v1/worker/ec_connector_model_runner_mixin.py +++ b/vllm/v1/worker/ec_connector_model_runner_mixin.py @@ -36,6 +36,11 @@ def maybe_save_ec_to_connector( connector = get_ec_transfer() connector.save_caches(encoder_cache=encoder_cache, mm_hash=mm_hash) + @staticmethod + def maybe_wait_for_ec_save() -> None: + if has_ec_transfer(): + get_ec_transfer().wait_for_save() + @staticmethod def get_finished_ec_transfers( scheduler_output: "SchedulerOutput", diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 22a3f9d8d2dd..2e55bf227c8b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2190,6 +2190,7 @@ def _execute_mm_encoder( logger.debug("Finish execute for mm hash %s", mm_hash) self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) + self.maybe_wait_for_ec_save() return encoder_outputs def _gather_mm_embeddings(