Merged
14 changes: 12 additions & 2 deletions vllm_gaudi/platform.py
@@ -228,12 +228,22 @@ def insert_blocks_to_device(
         dst_block_indices: torch.Tensor,
     ) -> None:
         """Copy blocks from src_cache to dst_cache on HPU."""
+        # WA: https://github.com/pytorch/pytorch/issues/169656
+        view_as_uint = src_cache.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
+        if view_as_uint:
+            src_cache = src_cache.view(torch.uint8)
Copilot AI (Dec 5, 2025):
The original dtype information is lost after converting src_cache to uint8. Later references to src_cache.dtype on lines 240 and 245 will return torch.uint8 instead of the original FP8 dtype. Store the original dtype in a variable before the conversion: original_dtype = src_cache.dtype and use original_dtype in the view conversions.

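A minimal sketch of the change this comment is asking for, hedged: it reuses the diff's variable names, only the non-tuple branch is shown, the wrapper name copy_blocks_fixed is invented for illustration, and dst_cache is assumed to live on the HPU so the final index_copy_ does not hit the CPU FP8 limitation.

```python
import torch

def copy_blocks_fixed(src_cache: torch.Tensor, dst_cache: torch.Tensor,
                      src_block_indices: torch.Tensor,
                      dst_block_indices: torch.Tensor) -> None:
    # Remember the FP8 dtype before reinterpreting the cache as uint8.
    original_dtype = src_cache.dtype
    view_as_uint = original_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
    if view_as_uint:
        src_cache = src_cache.view(torch.uint8)
    # Index on the CPU through the uint8 view, then restore the FP8 dtype.
    indexed_cache = src_cache[src_block_indices]
    if view_as_uint:
        indexed_cache = indexed_cache.view(original_dtype)  # not src_cache.dtype, which is now uint8
    dst_cache.index_copy_(0, dst_block_indices, indexed_cache.to(dst_cache.device))
```
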
Collaborator:
View as uint8? Can you explain in more detail how it helps here?

Contributor Author:
index_cpu doesn't support the fp8 data type; the view as uint8 here is only for data movement.
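A small, hedged illustration of that point (not code from the PR): indexing an FP8 tensor on the CPU can fail because the CPU index kernels are not implemented for FP8 dtypes, while the same bytes move fine through a uint8 view.

```python
import torch

cache = torch.randn(8, 16).to(torch.float8_e4m3fn)  # FP8 cache blocks on CPU
block_indices = torch.tensor([0, 3, 5])

# cache[block_indices] may raise on CPU builds affected by pytorch/pytorch#169656,
# so reinterpret the storage as uint8 purely for the data movement ...
moved = cache.view(torch.uint8)[block_indices]
# ... and view it back as FP8 once the gather is done.
moved = moved.view(torch.float8_e4m3fn)
assert moved.dtype == torch.float8_e4m3fn and moved.shape == (3, 16)
```
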

         if isinstance(dst_cache, tuple):
             _src_cache = src_cache[:, src_block_indices]
             for i in range(len(dst_cache)):
-                dst_cache[i].index_copy_(0, dst_block_indices, _src_cache[i].to(dst_cache[i].device))
+                indexed_cache = _src_cache[i]
+                if view_as_uint:
+                    indexed_cache = indexed_cache.view(src_cache.dtype)
Copilot AI (Dec 5, 2025):
This attempts to view as src_cache.dtype, but src_cache was already converted to uint8 on line 234, so this will view as uint8 again instead of the original FP8 dtype. Use the original dtype stored before the conversion.

+                dst_cache[i].index_copy_(0, dst_block_indices, indexed_cache.to(dst_cache[i].device))
         else:
-            dst_cache.index_copy_(0, dst_block_indices, src_cache[src_block_indices].to(dst_cache.device))
+            indexed_cache = src_cache[src_block_indices]
+            if view_as_uint:
+                indexed_cache = indexed_cache.view(src_cache.dtype)
Copilot AI (Dec 5, 2025):
Same issue as the tuple path: this views as src_cache.dtype which is now uint8, not the original FP8 dtype. Use the stored original dtype instead.

+            dst_cache.index_copy_(0, dst_block_indices, indexed_cache.to(dst_cache.device))
         torch.hpu.synchronize()

     @classmethod