diff --git a/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example.sh b/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example.sh
old mode 100644
new mode 100755
diff --git a/vllm/distributed/ec_transfer/ec_connector/base.py b/vllm/distributed/ec_transfer/ec_connector/base.py
index 2b7b14d89b8a..feda96ba7f6f
--- a/vllm/distributed/ec_transfer/ec_connector/base.py
+++ b/vllm/distributed/ec_transfer/ec_connector/base.py
@@ -159,6 +159,22 @@ def save_caches(
         """
         pass
 
+    @abstractmethod
+    def maybe_update_remote_cache_state(
+        self, encoder_cache: dict[str, torch.Tensor]
+    ) -> None:
+        """
+        Maybe update the remote cache state based on the local encoder cache.
+
+        This method can be used to synchronize or update the state of the
+        remote cache based on changes in the local encoder cache.
+
+        Args:
+            encoder_cache (dict[str, torch.Tensor]): A dictionary mapping
+                multimodal data hashes (`mm_hash`) to encoder cache tensors.
+        """
+        pass
+
     def get_finished(
         self, finished_req_ids: set[str]
     ) -> tuple[set[str] | None, set[str] | None]:
@@ -199,7 +215,9 @@ def has_caches(
         pass
 
     @abstractmethod
-    def update_state_after_alloc(self, request: "Request", index: int):
+    def update_state_after_alloc(
+        self, request: "Request", index: int, local_hit: bool, remote_hit: bool
+    ):
         """
-        Update ECConnector state to decide allocate cache for requests
+        Update ECConnector state to decide whether to allocate cache
+        for requests.
 
diff --git a/vllm/distributed/ec_transfer/ec_connector/example_connector.py b/vllm/distributed/ec_transfer/ec_connector/example_connector.py
old mode 100644
new mode 100755
index 5f2eff5a8e6a..a1a7697d452f
--- a/vllm/distributed/ec_transfer/ec_connector/example_connector.py
+++ b/vllm/distributed/ec_transfer/ec_connector/example_connector.py
@@ -1,10 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import contextlib
+import json
 import os
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
 import safetensors
+from filelock import FileLock
 
 from vllm.config import VllmConfig
 from vllm.distributed.ec_transfer.ec_connector.base import (
@@ -21,25 +25,29 @@
 logger = init_logger(__name__)
 
 
-@dataclass
-class MMMeta:
-    mm_hash: str
-    num_token: int
-
-    @staticmethod
-    def make_meta(mm_hash, num_token) -> "MMMeta":
-        return MMMeta(mm_hash=mm_hash, num_token=num_token)
+def _get_file_lock(path: str) -> FileLock:
+    lock_path = path + ".lock"
+    lock = FileLock(lock_path)
+    return lock
 
 
 @dataclass
 class ECExampleConnectorMetadata(ECConnectorMetadata):
-    mm_datas: list[MMMeta]
-
     def __init__(self):
-        self.mm_datas = []
+        self.mm_datas_to_load: list[str] = []
+        self.mm_datas_to_save: list[str] = []
+        self.mm_datas_to_update: dict[str, int] = {}
+
+    def add_meta_to_load(self, mm_hash: str):
+        self.mm_datas_to_load.append(mm_hash)
 
-    def add_mm_data(self, mm_data: MMMeta):
-        self.mm_datas.append(mm_data)
+    def add_meta_to_save(self, mm_hash: str):
+        self.mm_datas_to_save.append(mm_hash)
+
+    def add_meta_to_update(self, mm_hash: str):
+        if mm_hash not in self.mm_datas_to_update:
+            self.mm_datas_to_update[mm_hash] = 0
+        self.mm_datas_to_update[mm_hash] += 1
 
 
 class ECExampleConnector(ECConnectorBase):
@@ -48,8 +56,10 @@ class ECExampleConnector(ECConnectorBase):
 
     def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
         super().__init__(vllm_config=vllm_config, role=role)
-        # req_id -> index
-        self._mm_datas_need_loads: dict[str, int] = {}
+        self._mm_datas_need_loads: set[str] = set()
+        self._mm_datas_need_saves: set[str] = set()
+        # mm_hash values needing meta updates (read/write depending on role)
+        self._mm_datas_need_update_meta: list[str] = []
         transfer_config = vllm_config.ec_transfer_config
         if transfer_config is not None:
             self._storage_path = transfer_config.get_from_extra_config(
@@ -60,6 +70,17 @@ def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
         else:
             raise ValueError("ec_transfer_config must be set for ECConnectorBase")
 
+        # Whether saved caches are reference-counted and deallocated once read
+        self._deallocate_cache_enabled = (
+            transfer_config.get_from_extra_config("deallocate_cache", False)
+            if transfer_config
+            else False
+        )
+        logger.info(
+            "deallocate_cache enabled: %s",
+            self._deallocate_cache_enabled,
+        )
+
     def start_load_caches(self, encoder_cache, **kwargs) -> None:
         """
         Start loading the cache from the connector into vLLM's encoder cache.
@@ -87,13 +108,24 @@ def start_load_caches(self, encoder_cache, **kwargs) -> None:
             )
             return
         # Load the EC for each mm data
-        for mm_data in metadata.mm_datas:
-            if mm_data.mm_hash in encoder_cache:
+        for mm_hash in metadata.mm_datas_to_load:
+            if mm_hash in encoder_cache:
                 continue
-            filename = self._generate_filename_debug(mm_data.mm_hash)
-            ec_cache = safetensors.torch.load_file(filename)["ec_cache"].cuda()
-            encoder_cache[mm_data.mm_hash] = ec_cache
-            logger.debug("Success load encoder cache for hash %s", mm_data.mm_hash)
+
+            filename = self._generate_filename_debug(mm_hash)
+            try:
+                ec_cache = safetensors.torch.load_file(filename)["ec_cache"].cuda()
+                encoder_cache[mm_hash] = ec_cache
+                logger.debug("Successfully loaded encoder cache for hash %s", mm_hash)
+            except Exception as e:
+                logger.error(
+                    "Failed to load encoder cache for %s: %s",
+                    mm_hash,
+                    str(e),
+                )
+
+            if self._deallocate_cache_enabled:
+                self.update_mm_meta(mm_hash, 1)
 
     def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
         """
@@ -111,10 +143,15 @@ def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
         # Return if it is PD Instance
         if not self.is_producer:
             return
+
         filename = self._generate_filename_debug(mm_hash)
         ec_cache = encoder_cache[mm_hash]
         tensors = {"ec_cache": ec_cache.detach().cpu()}
         safetensors.torch.save_file(tensors, filename)
+
+        if self._deallocate_cache_enabled:
+            self.update_mm_meta(mm_hash, 1)
+
         logger.debug("Save cache successful for mm_hash %s", mm_hash)
 
     def has_caches(
@@ -139,14 +176,35 @@ def update_state_after_alloc(
         self,
         request: "Request",
         index: int,
+        local_hit: bool,
+        remote_hit: bool,
     ) -> None:
         """
         Update ECConnector state after encoder cache allocation.
         """
         mm_hash = request.mm_features[index].identifier
-        num_encoder_token = request.get_num_encoder_tokens(index)
-        # Insert mm_hash only if this block has not been recorded yet.
-        self._mm_datas_need_loads[mm_hash] = num_encoder_token
+        if remote_hit and not local_hit:
+            self._mm_datas_need_loads.add(mm_hash)
+        elif not remote_hit and local_hit:
+            self._mm_datas_need_saves.add(mm_hash)
+        elif remote_hit and local_hit:
+            self._mm_datas_need_update_meta.append(mm_hash)
+
+    def maybe_update_remote_cache_state(self, encoder_cache, **kwargs) -> None:
+        """Flush pending saves and meta updates recorded by the scheduler."""
+        metadata = self._get_connector_metadata()
+        assert isinstance(metadata, ECExampleConnectorMetadata)
+
+        for mm_hash in metadata.mm_datas_to_save:
+            if (not self.is_producer) or (mm_hash not in encoder_cache):
+                continue
+
+            self.save_caches(
+                encoder_cache=encoder_cache,
+                mm_hash=mm_hash,
+            )
+
+        for mm_hash, count in metadata.mm_datas_to_update.items():
+            self.update_mm_meta(mm_hash, count)
 
     def build_connector_meta(
         self,
@@ -161,9 +219,16 @@ def build_connector_meta(
             scheduler_output (SchedulerOutput): the scheduler output object.
         """
         meta = ECExampleConnectorMetadata()
-        for mm_hash, num_encoder_token in self._mm_datas_need_loads.items():
-            meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token))
+        for mm_hash in self._mm_datas_need_loads:
+            meta.add_meta_to_load(mm_hash)
+        for mm_hash in self._mm_datas_need_saves:
+            meta.add_meta_to_save(mm_hash)
+        for mm_hash in self._mm_datas_need_update_meta:
+            meta.add_meta_to_update(mm_hash)
+
         self._mm_datas_need_loads.clear()
+        self._mm_datas_need_saves.clear()
+        self._mm_datas_need_update_meta.clear()
         return meta
 
     # ==============================
@@ -199,3 +264,53 @@ def _generate_filename_debug(self, mm_hash: str) -> str:
         """
         foldername = self._generate_foldername_debug(mm_hash)  # <- folder auto-created
        return os.path.join(foldername, "encoder_cache.safetensors")
+
+    def _generate_meta_filename(self, mm_hash: str) -> str:
+        """
+        Return the full path of the metadata JSON file for this mm_hash.
+        """
+        foldername = self._generate_foldername_debug(mm_hash)
+        return os.path.join(foldername, "meta.json")
+
+    def update_mm_meta(self, mm_hash: str, count: int) -> None:
+        """
+        Create or update the metadata file for the given mm_hash.
+        Increment the read (or write) count by `count` when the connector
+        is a consumer (or producer). Once the read count matches the write
+        count, the cache file is removed.
+        """
+        WRITE_COUNT = "write_count"
+        READ_COUNT = "read_count"
+        # No-op when deallocation metadata behavior is disabled.
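+        # Example: after the producer's save, meta.json holds
+        #   {"write_count": 1, "read_count": 0}
+        # and a single consumer load then raises read_count to 1, which
+        # matches write_count and triggers deletion of both files below.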
+        if not self._deallocate_cache_enabled:
+            return
+
+        read_count = count if not self.is_producer else 0
+        write_count = count if self.is_producer else 0
+
+        meta_filename = self._generate_meta_filename(mm_hash)
+
+        lock = _get_file_lock(meta_filename)
+        # Acquire the per-file lock before reading/writing metadata
+        with lock:
+            if os.path.exists(meta_filename):
+                # Update existing meta
+                with open(meta_filename, "r") as f:
+                    data = json.load(f)
+                    data[WRITE_COUNT] += write_count
+                    data[READ_COUNT] += read_count
+
+                if data[WRITE_COUNT] == data[READ_COUNT] and data[READ_COUNT] > 0:
+                    tensorfile = self._generate_filename_debug(mm_hash)
+                    with contextlib.suppress(FileNotFoundError):
+                        os.remove(tensorfile)
+                    os.remove(meta_filename)
+                    return
+            else:
+                data = {
+                    WRITE_COUNT: write_count,
+                    READ_COUNT: read_count,
+                }
+
+            with open(meta_filename, "w") as f:
+                json.dump(data, f, indent=4)
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
old mode 100644
new mode 100755
index 278970ae7ee8..1b41ed471f3f
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -280,14 +280,15 @@ def schedule(self) -> SchedulerOutput:
 
         # Schedule encoder inputs.
         encoder_inputs_to_schedule = None
-        external_load_encoder_input: list[int] = []
+        # List of tuples: (media_index, local_hit, remote_hit)
+        external_update_encoder_input: list[tuple[int, bool, bool]] = []
         new_encoder_compute_budget = encoder_compute_budget
         if request.has_encoder_inputs:
             (
                 encoder_inputs_to_schedule,
                 num_new_tokens,
                 new_encoder_compute_budget,
-                external_load_encoder_input,
+                external_update_encoder_input,
             ) = self._try_schedule_encoder_inputs(
                 request,
                 request.num_computed_tokens,
@@ -402,11 +403,14 @@ def schedule(self) -> SchedulerOutput:
             for i in encoder_inputs_to_schedule:
                 self.encoder_cache_manager.allocate(request, i)
             encoder_compute_budget = new_encoder_compute_budget
-        if external_load_encoder_input:
-            for i in external_load_encoder_input:
-                self.encoder_cache_manager.allocate(request, i)
+        if external_update_encoder_input:
+            for i, local_hit, remote_hit in external_update_encoder_input:
+                if not local_hit:
+                    self.encoder_cache_manager.allocate(request, i)
                 if self.ec_connector is not None:
-                    self.ec_connector.update_state_after_alloc(request, i)
+                    self.ec_connector.update_state_after_alloc(
+                        request, i, local_hit, remote_hit
+                    )
 
         # Record the LoRAs in scheduled_running_reqs
         scheduled_loras: set[int] = set()
@@ -511,7 +515,7 @@ def schedule(self) -> SchedulerOutput:
             num_computed_tokens = request.num_computed_tokens
 
             encoder_inputs_to_schedule = None
-            external_load_encoder_input = []
+            external_update_encoder_input = []
             new_encoder_compute_budget = encoder_compute_budget
 
             if load_kv_async:
@@ -547,7 +551,7 @@ def schedule(self) -> SchedulerOutput:
                     encoder_inputs_to_schedule,
                     num_new_tokens,
                     new_encoder_compute_budget,
-                    external_load_encoder_input,
+                    external_update_encoder_input,
                 ) = self._try_schedule_encoder_inputs(
                     request,
                     num_computed_tokens,
@@ -651,11 +655,14 @@ def schedule(self) -> SchedulerOutput:
                     self.encoder_cache_manager.allocate(request, i)
                 encoder_compute_budget = new_encoder_compute_budget
-            # Allocate for external load encoder cache
-            if external_load_encoder_input:
-                for i in external_load_encoder_input:
-                    self.encoder_cache_manager.allocate(request, i)
+            # Allocate for externally updated encoder caches
+            if external_update_encoder_input:
+                for i, local_hit, remote_hit in external_update_encoder_input:
+                    if not local_hit:
+                        self.encoder_cache_manager.allocate(request, i)
                 if self.ec_connector is not None:
-                    self.ec_connector.update_state_after_alloc(request, i)
+                        self.ec_connector.update_state_after_alloc(
+                            request, i, local_hit, remote_hit
+                        )
 
         # Put back any skipped requests at the head of the waiting queue
         if skipped_waiting_requests:
             self.waiting.prepend_requests(skipped_waiting_requests)
@@ -875,7 +882,7 @@ def _try_schedule_encoder_inputs(
         num_new_tokens: int,
         encoder_compute_budget: int,
         shift_computed_tokens: int = 0,
-    ) -> tuple[list[int], int, int, list[int]]:
+    ) -> tuple[list[int], int, int, list[tuple[int, bool, bool]]]:
         """
         Determine which encoder inputs need to be scheduled in the current
         step, and update `num_new_tokens` and encoder token budget accordingly.
@@ -902,7 +909,9 @@ def _try_schedule_encoder_inputs(
         mm_features = request.mm_features
         assert mm_features is not None
         assert len(mm_features) > 0
-        external_load_encoder_input = []
+        # Tuples of (media_index, local_hit, remote_hit) used to track
+        # encoder cache state for the connector
+        external_update_encoder_input: list[tuple[int, bool, bool]] = []
 
         # Check remote cache first
         if self.ec_connector is not None:
@@ -957,6 +966,12 @@ def _try_schedule_encoder_inputs(
             if self.encoder_cache_manager.check_and_update_cache(request, i):
                 # The encoder input is already computed and cached from a
                 # previous step.
+                # Record the local hit so the connector can optionally update
+                # the remote cache state for this input.
+                if self.ec_connector is not None:
+                    external_update_encoder_input.append(
+                        (i, True, remote_cache_has_item[i])
+                    )
                 continue
 
             # If no encoder input chunking is allowed, we do not want to
@@ -994,7 +1009,9 @@ def _try_schedule_encoder_inputs(
 
             if self.ec_connector is not None and remote_cache_has_item[i]:
                 mm_hashes_to_schedule.add(request.mm_features[i].identifier)
-                external_load_encoder_input.append(i)
+                external_update_encoder_input.append(
+                    (i, False, remote_cache_has_item[i])
+                )
                 num_tokens_to_schedule += num_encoder_tokens
                 continue
 
@@ -1007,7 +1024,7 @@ def _try_schedule_encoder_inputs(
             encoder_inputs_to_schedule,
             num_new_tokens,
             encoder_compute_budget,
-            external_load_encoder_input,
+            external_update_encoder_input,
         )
 
     def get_grammar_bitmask(
diff --git a/vllm/v1/worker/ec_connector_model_runner_mixin.py b/vllm/v1/worker/ec_connector_model_runner_mixin.py
old mode 100644
new mode 100755
index 08a41532ea8e..b0ecd49e10b6
--- a/vllm/v1/worker/ec_connector_model_runner_mixin.py
+++ b/vllm/v1/worker/ec_connector_model_runner_mixin.py
@@ -31,7 +31,6 @@ def maybe_save_ec_to_connector(
         mm_hash: str,
     ):
         if not has_ec_transfer():
-            logger.debug("Not have ec transfer please check")
             return
         connector = get_ec_transfer()
         connector.save_caches(encoder_cache=encoder_cache, mm_hash=mm_hash)
@@ -83,5 +82,5 @@ def _get_ec_connector_output(
             output.finished_sending, output.finished_recving = (
                 ec_connector.get_finished(scheduler_output.finished_req_ids)
            )
-
+        ec_connector.maybe_update_remote_cache_state(encoder_cache)
         ec_connector.clear_connector_metadata()
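Reviewer note: the following standalone sketch (not part of the patch) illustrates the meta.json read/write reference-counting protocol that update_mm_meta implements above. The on-disk layout mirrors the connector's (encoder_cache.safetensors plus meta.json in a per-mm_hash folder); the update_meta helper and the temp-directory demo are illustrative assumptions, not vLLM APIs.

import json
import os
import tempfile

from filelock import FileLock

WRITE_COUNT = "write_count"
READ_COUNT = "read_count"


def update_meta(folder: str, is_producer: bool, count: int = 1) -> None:
    """Bump one counter under a per-file lock; delete files when counts match."""
    meta_path = os.path.join(folder, "meta.json")
    tensor_path = os.path.join(folder, "encoder_cache.safetensors")
    with FileLock(meta_path + ".lock"):
        if os.path.exists(meta_path):
            with open(meta_path) as f:
                data = json.load(f)
        else:
            data = {WRITE_COUNT: 0, READ_COUNT: 0}
        data[WRITE_COUNT if is_producer else READ_COUNT] += count
        if data[WRITE_COUNT] == data[READ_COUNT] and data[READ_COUNT] > 0:
            # Every save has been consumed: drop the tensor and its metadata.
            if os.path.exists(tensor_path):
                os.remove(tensor_path)
            os.remove(meta_path)
            return
        with open(meta_path, "w") as f:
            json.dump(data, f, indent=4)


if __name__ == "__main__":
    folder = tempfile.mkdtemp()
    # Stand-in for the safetensors file the producer would write.
    open(os.path.join(folder, "encoder_cache.safetensors"), "wb").close()
    update_meta(folder, is_producer=True)  # producer saved the cache
    print(sorted(os.listdir(folder)))  # tensor, meta.json, and its .lock file
    update_meta(folder, is_producer=False)  # consumer loaded the cache
    print(sorted(os.listdir(folder)))  # only the .lock file remains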