diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py
index 46f0cd9289b2..724ce367a298 100644
--- a/vllm/distributed/kv_events.py
+++ b/vllm/distributed/kv_events.py
@@ -51,6 +51,7 @@ class BlockStored(KVCacheEvent):
     block_size: int
     lora_id: Optional[int]
     medium: Optional[str]
+    mm_hashes: list[list[str]]
 
 
 class BlockRemoved(KVCacheEvent):
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 4d8206bb2d83..e37e98ecde90 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -116,6 +116,7 @@ def __init__(
         *,
         multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
         kv_transfer_params: Optional[dict[str, Any]] = None,
+        mm_hashes: Optional[list[str]] = None,
         # Forward compatibility, code that uses args added in new release can
         # still run with older versions of vLLM without breaking.
         **kwargs: Any,
@@ -136,6 +137,7 @@ def __init__(
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
         self.num_cached_tokens = num_cached_tokens
         self.kv_transfer_params = kv_transfer_params
+        self.mm_hashes = mm_hashes
 
     def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
         """Merge subsequent RequestOutput into this one"""
diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py
index 3cc738304821..db2508a9c2cd 100644
--- a/vllm/v1/core/block_pool.py
+++ b/vllm/v1/core/block_pool.py
@@ -252,7 +252,10 @@ def cache_full_blocks(
                     lora_id=request.lora_request.id
                     if request.lora_request else None,
                     medium=MEDIUM_GPU,
-                ))
+                    mm_hashes=self._get_block_mm_hash(request,
+                                                      num_cached_blocks,
+                                                      num_full_blocks,
+                                                      block_size)))
 
     def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
         """Get new blocks from the free block pool.
@@ -414,3 +417,18 @@ def take_events(self) -> list[KVCacheEvent]:
         events = self.kv_event_queue
         self.kv_event_queue = []
         return events
+
+    def _get_block_mm_hash(self, request: Request, num_cached_blocks: int,
+                           num_full_blocks: int, block_size: int):
+        block_mm_hashes: list[list[str]] = [[]
+                                            for _ in range(num_full_blocks -
+                                                           num_cached_blocks)]
+        start_token_idx = num_cached_blocks * block_size
+        end_token_idx = num_full_blocks * block_size
+        for mm_feature in request.mm_features:
+            if (mm_feature.mm_position.offset >= start_token_idx
+                    and mm_feature.mm_position.offset < end_token_idx):
+                block_mm_hashes[(mm_feature.mm_position.offset -
+                                 start_token_idx) // block_size].append(
+                                     mm_feature.identifier)
+        return block_mm_hashes
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 38b2d6824b47..2e36ba2e2f0d 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -98,6 +98,7 @@ def __init__(
         top_p: Optional[float] = None,
         n: Optional[int] = None,
         temperature: Optional[float] = None,
+        mm_hashes=None,
     ):
         self.request_id = request_id
         self.parent_req = parent_req
@@ -118,6 +119,7 @@
         self.is_prefilling = True
         self.queue = queue
         self.num_cached_tokens = 0
+        self.mm_hashes = mm_hashes
         self.stats = RequestStateStats(
             arrival_time=arrival_time) if log_stats else None
 
@@ -132,6 +134,7 @@ def from_new_request(
         request_index: int,
         queue: Optional[RequestOutputCollector],
         log_stats: bool,
+        mm_hashes: list[str],
     ) -> "RequestState":
 
         if sampling_params := request.sampling_params:
@@ -179,6 +182,7 @@
             arrival_time=request.arrival_time,
             queue=queue,
             log_stats=log_stats,
+            mm_hashes=mm_hashes,
         )
 
     def make_request_output(
@@ -257,6 +261,7 @@
             finished=finished,
             kv_transfer_params=kv_transfer_params,
             num_cached_tokens=self.num_cached_tokens,
+            mm_hashes=self.mm_hashes,
         )
 
     def _new_completion_output(
@@ -365,13 +370,15 @@ def add_request(
         if request_id in self.request_states:
             raise ValueError(f"Request id {request_id} already running.")
 
+        mm_hashes = [f.identifier for f in request.mm_features]
         req_state = RequestState.from_new_request(tokenizer=self.tokenizer,
                                                   request=request,
                                                   prompt=prompt,
                                                   parent_req=parent_req,
                                                   request_index=request_index,
                                                   queue=queue,
-                                                  log_stats=self.log_stats)
+                                                  log_stats=self.log_stats,
+                                                  mm_hashes=mm_hashes)
         self.request_states[request_id] = req_state
         self.lora_states.add_request(req_state)
         if parent_req:
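
For reference, a standalone sketch of the bucketing that `_get_block_mm_hash` performs in the `block_pool.py` hunk above: each multimodal feature's hash is attributed to the newly cached block containing the feature's first token. The `MMPosition`/`MMFeature` dataclasses here are simplified stand-ins for vLLM's internal types, not its real API.

```python
from dataclasses import dataclass


@dataclass
class MMPosition:
    offset: int  # index of the feature's first placeholder token in the prompt


@dataclass
class MMFeature:
    identifier: str  # content hash of the multimodal item
    mm_position: MMPosition


def get_block_mm_hash(mm_features: list[MMFeature], num_cached_blocks: int,
                      num_full_blocks: int,
                      block_size: int) -> list[list[str]]:
    """Group mm hashes by the newly cached full block their feature starts in."""
    # One (possibly empty) hash list per newly cached full block.
    block_mm_hashes: list[list[str]] = [
        [] for _ in range(num_full_blocks - num_cached_blocks)
    ]
    start_token_idx = num_cached_blocks * block_size
    end_token_idx = num_full_blocks * block_size
    for feature in mm_features:
        offset = feature.mm_position.offset
        # Only features starting inside the newly cached token range are
        # recorded; each lands in the block containing its first token.
        if start_token_idx <= offset < end_token_idx:
            block_mm_hashes[(offset - start_token_idx) //
                            block_size].append(feature.identifier)
    return block_mm_hashes


# Example: block_size=16, blocks 0-1 already cached, blocks 2-3 newly full.
# An image whose placeholders start at token offset 40 lands in block 2,
# i.e. the first entry of the returned per-new-block list.
features = [MMFeature("img-abc123", MMPosition(offset=40))]
print(get_block_mm_hash(features, num_cached_blocks=2,
                        num_full_blocks=4, block_size=16))
# -> [['img-abc123'], []]
```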