From 859e9aaf83d73121225f20c5a08db183349e6c1f Mon Sep 17 00:00:00 2001 From: Xinyu Chen Date: Fri, 5 Dec 2025 17:06:42 +0800 Subject: [PATCH] pd: support fp8 kvcache in insert_blocks_to_device Signed-off-by: Xinyu Chen --- vllm_gaudi/platform.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/vllm_gaudi/platform.py b/vllm_gaudi/platform.py index d05dce253..851c8f04f 100644 --- a/vllm_gaudi/platform.py +++ b/vllm_gaudi/platform.py @@ -228,12 +228,22 @@ def insert_blocks_to_device( dst_block_indices: torch.Tensor, ) -> None: """Copy blocks from src_cache to dst_cache on HPU.""" + # WA: https://github.com/pytorch/pytorch/issues/169656 + view_as_uint = src_cache.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + if view_as_uint: + src_cache = src_cache.view(torch.uint8) if isinstance(dst_cache, tuple): _src_cache = src_cache[:, src_block_indices] for i in range(len(dst_cache)): - dst_cache[i].index_copy_(0, dst_block_indices, _src_cache[i].to(dst_cache[i].device)) + indexed_cache = _src_cache[i] + if view_as_uint: + indexed_cache = indexed_cache.view(src_cache.dtype) + dst_cache[i].index_copy_(0, dst_block_indices, indexed_cache.to(dst_cache[i].device)) else: - dst_cache.index_copy_(0, dst_block_indices, src_cache[src_block_indices].to(dst_cache.device)) + indexed_cache = src_cache[src_block_indices] + if view_as_uint: + indexed_cache = indexed_cache.view(src_cache.dtype) + dst_cache.index_copy_(0, dst_block_indices, indexed_cache.to(dst_cache.device)) torch.hpu.synchronize() @classmethod