1 change: 1 addition & 0 deletions vllm/distributed/kv_events.py
@@ -51,6 +51,7 @@ class BlockStored(KVCacheEvent):
block_size: int
lora_id: Optional[int]
medium: Optional[str]
+ mm_hashes: list[list[str]]


class BlockRemoved(KVCacheEvent):
2 changes: 2 additions & 0 deletions vllm/outputs.py
@@ -116,6 +116,7 @@ def __init__(
*,
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
kv_transfer_params: Optional[dict[str, Any]] = None,
+ mm_hashes: Optional[list[str]] = None,
# Forward compatibility, code that uses args added in new release can
# still run with older versions of vLLM without breaking.
**kwargs: Any,
@@ -136,6 +137,7 @@ def __init__(
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens
self.kv_transfer_params = kv_transfer_params
+ self.mm_hashes = mm_hashes

def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
"""Merge subsequent RequestOutput into this one"""
20 changes: 19 additions & 1 deletion vllm/v1/core/block_pool.py
@@ -252,7 +252,10 @@ def cache_full_blocks(
lora_id=request.lora_request.id
if request.lora_request else None,
medium=MEDIUM_GPU,
- ))
+ mm_hashes=self._get_block_mm_hash(request,
+                                   num_cached_blocks,
+                                   num_full_blocks,
+                                   block_size)))

def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
"""Get new blocks from the free block pool.
@@ -414,3 +417,18 @@ def take_events(self) -> list[KVCacheEvent]:
events = self.kv_event_queue
self.kv_event_queue = []
return events

+ def _get_block_mm_hash(self, request: Request, num_cached_blocks: int,
+                        num_full_blocks: int, block_size: int):
+     block_mm_hashes: list[list[str]] = [[]
+                                         for _ in range(num_full_blocks -
+                                                        num_cached_blocks)]
+     start_token_idx = num_cached_blocks * block_size
+     end_token_idx = num_full_blocks + block_size

Review comment (critical):

There appears to be a bug in the calculation of end_token_idx. It should be num_full_blocks * block_size to correctly represent the end token index of the full blocks. The current calculation, num_full_blocks + block_size, mixes a block count with the block size, which produces an incorrect token range and associates multi-modal features with the wrong blocks.

Suggested change:
- end_token_idx = num_full_blocks + block_size
+ end_token_idx = num_full_blocks * block_size

+ for mm_feature in request.mm_features:
+     if (mm_feature.mm_position.offset >= start_token_idx
+             and mm_feature.mm_position.offset < end_token_idx):
+         block_mm_hashes[(mm_feature.mm_position.offset -
+                          start_token_idx) // block_size].append(
+                              mm_feature.identifier)
+ return block_mm_hashes
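
To make the reviewer's point concrete, here is a minimal runnable sketch of the corrected index arithmetic. This is a standalone helper written for illustration only; it is not part of this diff, though the parameter names mirror those in _get_block_mm_hash above.

# Sketch: token range covered by the newly cached full blocks. With the
# suggested fix, the end index is a block count multiplied by the block
# size, never a count added to a size.
def full_block_token_range(num_cached_blocks: int, num_full_blocks: int,
                           block_size: int) -> tuple[int, int]:
    start_token_idx = num_cached_blocks * block_size
    end_token_idx = num_full_blocks * block_size  # fixed: multiply, not add
    return start_token_idx, end_token_idx

# With block_size=16, 2 blocks already cached and 5 full blocks, the
# window is tokens [32, 80); the buggy "num_full_blocks + block_size"
# would give end_token_idx = 21 and drop features in the later blocks.
assert full_block_token_range(2, 5, 16) == (32, 80)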
9 changes: 8 additions & 1 deletion vllm/v1/engine/output_processor.py
@@ -98,6 +98,7 @@
top_p: Optional[float] = None,
n: Optional[int] = None,
temperature: Optional[float] = None,
+ mm_hashes=None,

Review comment (high):

The type hint for the mm_hashes parameter is missing. For consistency with the other parameters and to improve clarity and maintainability, add the appropriate type hint.

Suggested change:
- mm_hashes=None,
+ mm_hashes: Optional[list[str]] = None,

):
self.request_id = request_id
self.parent_req = parent_req
@@ -118,6 +119,7 @@
self.is_prefilling = True
self.queue = queue
self.num_cached_tokens = 0
+ self.mm_hashes = mm_hashes

self.stats = RequestStateStats(
arrival_time=arrival_time) if log_stats else None
@@ -132,6 +134,7 @@
request_index: int,
queue: Optional[RequestOutputCollector],
log_stats: bool,
+ mm_hashes: list[str],
) -> "RequestState":

if sampling_params := request.sampling_params:
@@ -179,6 +182,7 @@
arrival_time=request.arrival_time,
queue=queue,
log_stats=log_stats,
+ mm_hashes=mm_hashes,
)

def make_request_output(
@@ -257,6 +261,7 @@
finished=finished,
kv_transfer_params=kv_transfer_params,
num_cached_tokens=self.num_cached_tokens,
+ mm_hashes=self.mm_hashes,
)

def _new_completion_output(
@@ -365,13 +370,15 @@
if request_id in self.request_states:
raise ValueError(f"Request id {request_id} already running.")

+ mm_hashes = [f.identifier for f in request.mm_features]

Check failure on line 373 in vllm/v1/engine/output_processor.py (GitHub Actions / pre-commit, reported four times):
Item "None" of "Optional[list[Any]]" has no attribute "__iter__" (not iterable) [union-attr]
req_state = RequestState.from_new_request(tokenizer=self.tokenizer,
request=request,
prompt=prompt,
parent_req=parent_req,
request_index=request_index,
queue=queue,
- log_stats=self.log_stats)
+ log_stats=self.log_stats,
+ mm_hashes=mm_hashes)
self.request_states[request_id] = req_state
self.lora_states.add_request(req_state)
if parent_req:
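
The pre-commit failure above comes from request.mm_features being typed as Optional, so iterating over it directly does not type-check. One possible way to satisfy mypy (a sketch of a fix, not a change made in this PR) is to guard the Optional before building the list:

# Possible fix for the union-attr failure: fall back to an empty list
# when request.mm_features is None.
mm_hashes = ([f.identifier for f in request.mm_features]
             if request.mm_features is not None else [])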