20 changes: 19 additions & 1 deletion vllm/distributed/ec_transfer/ec_connector/base.py
@@ -159,6 +159,22 @@ def save_caches(
"""
pass

@abstractmethod
def maybe_update_remote_cache_state(
self, encoder_cache: dict[str, torch.Tensor]
) -> None:
"""
Maybe update the remote cache state based on the local encoder cache.

This method can be used to synchronize or update the state of the
remote cache based on changes in the local encoder cache.

Args:
encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
data hashes (`mm_hash`) to encoder cache tensors.
"""
pass
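
For orientation, a minimal sketch of how a subclass with no remote bookkeeping could satisfy the new abstract method (illustrative fragment, not part of this diff; a no-op is a valid implementation):

    # Hypothetical subclass fragment: connectors that keep no remote
    # metadata can implement the new hook as a no-op.
    def maybe_update_remote_cache_state(
        self, encoder_cache: dict[str, torch.Tensor]
    ) -> None:
        # Nothing to reconcile for this connector.
        return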

def get_finished(
self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
@@ -199,7 +215,9 @@ def has_caches(
pass

@abstractmethod
def update_state_after_alloc(self, request: "Request", index: int):
def update_state_after_alloc(
self, request: "Request", index: int, local_hit: bool, remote_hit: bool
):
"""
Update ECConnector state to decide whether to allocate cache for requests.

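The two new flags partition allocations into three actionable cases; a comment-level sketch of the intended dispatch (it mirrors ECExampleConnector.update_state_after_alloc below):

    # Sketch only: how a connector is expected to react to the flags.
    if remote_hit and not local_hit:
        ...  # schedule a load from remote storage into the local cache
    elif local_hit and not remote_hit:
        ...  # schedule a save of the local cache to remote storage
    elif local_hit and remote_hit:
        ...  # both sides already have it; only refresh shared metadata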
167 changes: 141 additions & 26 deletions vllm/distributed/ec_transfer/ec_connector/example_connector.py
100644 → 100755
@@ -1,10 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import json
import os
import weakref
from dataclasses import dataclass
from typing import TYPE_CHECKING

import safetensors
from filelock import FileLock

from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
@@ -21,25 +25,29 @@
logger = init_logger(__name__)


@dataclass
class MMMeta:
mm_hash: str
num_token: int

@staticmethod
def make_meta(mm_hash, num_token) -> "MMMeta":
return MMMeta(mm_hash=mm_hash, num_token=num_token)
def _get_file_lock(path: str) -> FileLock:
lock_path = path + ".lock"
lock = FileLock(lock_path)
return lock
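
A short usage sketch for the helper (the path is made up; update_mm_meta below is the real call site):

    # FileLock serializes access across processes via the companion
    # ".lock" file, so concurrent metadata updates do not interleave.
    lock = _get_file_lock("/tmp/ec_cache/meta.json")
    with lock:
        pass  # read-modify-write meta.json while holding the lock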


@dataclass
class ECExampleConnectorMetadata(ECConnectorMetadata):
mm_datas: list[MMMeta]

def __init__(self):
self.mm_datas = []
self.mm_datas_to_load: list[str] = []
self.mm_datas_to_save: list[str] = []
self.mm_datas_to_update: dict[str, int] = {}

def add_meta_to_load(self, mm_hash: str):
self.mm_datas_to_load.append(mm_hash)

def add_mm_data(self, mm_data: MMMeta):
self.mm_datas.append(mm_data)
def add_meta_to_save(self, mm_hash: str):
self.mm_datas_to_save.append(mm_hash)

def add_meta_to_update(self, mm_hash: str):
if mm_hash not in self.mm_datas_to_update:
self.mm_datas_to_update[mm_hash] = 0
self.mm_datas_to_update[mm_hash] += 1
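
Illustrative population of the new metadata fields (hash strings invented):

    meta = ECExampleConnectorMetadata()
    meta.add_meta_to_load("hash_a")    # mm_datas_to_load == ["hash_a"]
    meta.add_meta_to_save("hash_b")    # mm_datas_to_save == ["hash_b"]
    meta.add_meta_to_update("hash_c")
    meta.add_meta_to_update("hash_c")  # mm_datas_to_update == {"hash_c": 2}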


class ECExampleConnector(ECConnectorBase):
@@ -48,8 +56,10 @@ class ECExampleConnector(ECConnectorBase):

def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
# req_id -> index
self._mm_datas_need_loads: dict[str, int] = {}
self._mm_datas_need_loads: set[str] = set()
self._mm_datas_need_saves: set[str] = set()
# list of mm_hash to update meta (read or write depending on role)
self._mm_datas_need_update_meta: list[str] = []
transfer_config = vllm_config.ec_transfer_config
if transfer_config is not None:
self._storage_path = transfer_config.get_from_extra_config(
@@ -60,6 +70,17 @@ def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
else:
raise ValueError("ec_transfer_config must be set for ECConnectorBase")

# Default deallocate_cache flag
self._deallocate_cache_enabled = (
transfer_config.get_from_extra_config("deallocate_cache", False)
if transfer_config
else False
)
logger.info(
"deallocate_cache enabled: %s",
self._deallocate_cache_enabled,
)
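
A sketch of how the flag might be supplied; the dict shape is an assumption, only the "deallocate_cache" key (default False) comes from this diff:

    # Hypothetical extra config enabling deallocation bookkeeping:
    ec_transfer_extra_config = {
        "deallocate_cache": True,  # read via get_from_extra_config above
    }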

def start_load_caches(self, encoder_cache, **kwargs) -> None:
"""
Start loading the cache from the connector into vLLM's encoder cache.
@@ -87,13 +108,24 @@ def start_load_caches(self, encoder_cache, **kwargs) -> None:
)
return
# Load the EC for each mm data
for mm_data in metadata.mm_datas:
if mm_data.mm_hash in encoder_cache:
for mm_hash in metadata.mm_datas_to_load:
if mm_hash in encoder_cache:
continue
filename = self._generate_filename_debug(mm_data.mm_hash)
ec_cache = safetensors.torch.load_file(filename)["ec_cache"].cuda()
encoder_cache[mm_data.mm_hash] = ec_cache
logger.debug("Success load encoder cache for hash %s", mm_data.mm_hash)

filename = self._generate_filename_debug(mm_hash)
try:
ec_cache = safetensors.torch.load_file(filename)["ec_cache"].cuda()
encoder_cache[mm_hash] = ec_cache
logger.debug("Success load encoder cache for hash %s", mm_hash)
except Exception as e:
logger.error(
"Failed to load encoder cache for %s: %s",
mm_hash,
str(e),
)

if self._deallocate_cache_enabled:
self.update_mm_meta(mm_hash, 1)

def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
"""
@@ -111,10 +143,15 @@ def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
# Return if it is PD Instance
if not self.is_producer:
return

filename = self._generate_filename_debug(mm_hash)
ec_cache = encoder_cache[mm_hash]
tensors = {"ec_cache": ec_cache.detach().cpu()}
safetensors.torch.save_file(tensors, filename)

if self._deallocate_cache_enabled:
self.update_mm_meta(mm_hash, 1)

logger.debug("Save cache successful for mm_hash %s", mm_hash)

def has_caches(
@@ -139,14 +176,35 @@ def update_state_after_alloc(
self,
request: "Request",
index: int,
local_hit: bool,
remote_hit: bool,
) -> None:
"""
Update ECConnector state after encoder cache allocation.
"""
mm_hash = request.mm_features[index].identifier
num_encoder_token = request.get_num_encoder_tokens(index)
# Insert mm_hash only if this block has not been recorded yet.
self._mm_datas_need_loads[mm_hash] = num_encoder_token
if remote_hit and not local_hit:
self._mm_datas_need_loads.add(mm_hash)
elif not remote_hit and local_hit:
self._mm_datas_need_saves.add(mm_hash)
elif remote_hit and local_hit:
self._mm_datas_need_update_meta.append(mm_hash)

def maybe_update_remote_cache_state(self, encoder_cache, **kwargs) -> None:
metadata = self._get_connector_metadata()
assert isinstance(metadata, ECExampleConnectorMetadata)

for mm_hash in metadata.mm_datas_to_save:
if (not self.is_producer) or (mm_hash not in encoder_cache):
continue

self.save_caches(
encoder_cache=encoder_cache,
mm_hash=mm_hash,
)

for mm_hash, count in metadata.mm_datas_to_update.items():
self.update_mm_meta(mm_hash, count)
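
A comment-level walkthrough of what one metadata payload triggers on a producer worker (hash values invented):

    # Given mm_datas_to_save = ["h_img"] and mm_datas_to_update = {"h_vid": 2},
    # a producer worker holding "h_img" in its encoder cache would:
    #   1. save_caches(encoder_cache, mm_hash="h_img")  -> write the tensor file
    #   2. update_mm_meta("h_vid", 2)                   -> bump write_count by 2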

def build_connector_meta(
self,
@@ -161,9 +219,16 @@ def build_connector_meta(
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = ECExampleConnectorMetadata()
for mm_hash, num_encoder_token in self._mm_datas_need_loads.items():
meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token))
for mm_hash in self._mm_datas_need_loads:
meta.add_meta_to_load(mm_hash)
for mm_hash in self._mm_datas_need_saves:
meta.add_meta_to_save(mm_hash)
for mm_hash in self._mm_datas_need_update_meta:
meta.add_meta_to_update(mm_hash)

self._mm_datas_need_loads.clear()
self._mm_datas_need_saves.clear()
self._mm_datas_need_update_meta.clear()
return meta

# ==============================
@@ -199,3 +264,53 @@ def _generate_filename_debug(self, mm_hash: str) -> str:
"""
foldername = self._generate_foldername_debug(mm_hash) # <- folder auto-created
return os.path.join(foldername, "encoder_cache.safetensors")

def _generate_meta_filename(self, mm_hash: str) -> str:
"""
Return the full path of the metadata JSON file for this mm_hash.
"""
foldername = self._generate_foldername_debug(mm_hash)
return os.path.join(foldername, "meta.json")

def update_mm_meta(self, mm_hash: str, count: int) -> None:
"""
Create or update the metadata file for the given mm_hash.
Increase the read (or write) count by `count`
when the connector is a consumer (or producer).
When the read count equals the write count, the cache
and metadata files are removed.
"""
WRITE_COUNT = "write_count"
READ_COUNT = "read_count"
# No-op when deallocation metadata behavior is disabled.
if not self._deallocate_cache_enabled:
return

read_count = count if not self.is_producer else 0
write_count = count if self.is_producer else 0

meta_filename = self._generate_meta_filename(mm_hash)

lock = _get_file_lock(meta_filename)
# Acquire per-file lock before reading/writing metadata
with lock:
if os.path.exists(meta_filename):
# Update existing meta
with open(meta_filename, "r+") as f:
data = json.load(f)
data[WRITE_COUNT] += write_count
data[READ_COUNT] += read_count

if data[WRITE_COUNT] == data[READ_COUNT] and data[READ_COUNT] > 0:
tensorfile = self._generate_filename_debug(mm_hash)
with contextlib.suppress(FileNotFoundError):
os.remove(tensorfile)
os.remove(meta_filename)
return
else:
data = {
WRITE_COUNT: write_count,
READ_COUNT: read_count,
}

with open(meta_filename, "w") as f:
json.dump(data, f, indent=4)
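
To make the count-matching rule concrete, a standalone restatement of the deallocation predicate used above (the helper name is illustrative):

    def should_deallocate(write_count: int, read_count: int) -> bool:
        # Files are removed once every recorded write has been matched
        # by a read, and at least one read has actually happened.
        return write_count == read_count and read_count > 0

    assert should_deallocate(1, 1)
    assert not should_deallocate(1, 0)  # saved but never read back
    assert not should_deallocate(0, 0)  # nothing recorded yet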