pd: support fp8 kvcache in insert_blocks_to_device #693
Conversation
Pull request overview
This PR adds support for the FP8 key-value cache types (float8_e4m3fn and float8_e5m2) in the insert_blocks_to_device method by working around a PyTorch indexing limitation: FP8 tensors are temporarily viewed as uint8 for the indexing operations and then viewed back to their original dtype (a short illustrative sketch follows the key-changes list).
Key changes:
- Added FP8 dtype detection and a uint8 view workaround
- Updated both tuple and non-tuple cache handling paths to support FP8 types
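A minimal sketch of the principle behind the workaround (illustrative only; `k_cache` is a made-up name, not from the PR): the FP8 dtypes are one byte wide, so `Tensor.view(torch.uint8)` reinterprets the existing storage rather than casting the values.

```python
import torch

# Illustrative: float8_e4m3fn / float8_e5m2 are 1-byte dtypes, so viewing as
# uint8 is a zero-copy reinterpretation of the same storage, not a numeric cast.
k_cache = torch.randn(4, 8).to(torch.float8_e4m3fn)  # hypothetical cache slice
as_bytes = k_cache.view(torch.uint8)
assert as_bytes.data_ptr() == k_cache.data_ptr()      # same storage, no copy
assert k_cache.element_size() == as_bytes.element_size() == 1
```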
# WA: https://github.com/pytorch/pytorch/issues/169656
view_as_uint = src_cache.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
if view_as_uint:
    src_cache = src_cache.view(torch.uint8)
Copilot AI · Dec 5, 2025
The original dtype information is lost after converting src_cache to uint8. Later references to src_cache.dtype on lines 240 and 245 will return torch.uint8 instead of the original FP8 dtype. Store the original dtype in a variable before the conversion: original_dtype = src_cache.dtype and use original_dtype in the view conversions.
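A hedged sketch of the suggested fix, written as a hypothetical standalone helper (names mirror the diff; the surrounding method is not reproduced here):

```python
import torch

def _gather_blocks(src_cache: torch.Tensor, src_block_indices: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper illustrating the suggested fix, not the actual PR code.
    original_dtype = src_cache.dtype  # capture BEFORE the uint8 reinterpretation
    # WA: https://github.com/pytorch/pytorch/issues/169656
    view_as_uint = original_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
    if view_as_uint:
        src_cache = src_cache.view(torch.uint8)
    indexed_cache = src_cache[src_block_indices]
    if view_as_uint:
        # Use the saved dtype; src_cache.dtype is torch.uint8 at this point.
        indexed_cache = indexed_cache.view(original_dtype)
    return indexed_cache
```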
dst_cache[i].index_copy_(0, dst_block_indices, _src_cache[i].to(dst_cache[i].device))
indexed_cache = _src_cache[i]
if view_as_uint:
    indexed_cache = indexed_cache.view(src_cache.dtype)
Copilot AI · Dec 5, 2025
This attempts to view as src_cache.dtype, but src_cache was already converted to uint8 on line 234, so this will view as uint8 again instead of the original FP8 dtype. Use the original dtype stored before the conversion.
dst_cache.index_copy_(0, dst_block_indices, src_cache[src_block_indices].to(dst_cache.device))
indexed_cache = src_cache[src_block_indices]
if view_as_uint:
    indexed_cache = indexed_cache.view(src_cache.dtype)
Copilot AI · Dec 5, 2025
Same issue as the tuple path: this views as src_cache.dtype which is now uint8, not the original FP8 dtype. Use the stored original dtype instead.
Signed-off-by: Xinyu Chen <[email protected]>
✅ CI Passed. All checks passed successfully against the following vllm commit:
# WA: https://github.com/pytorch/pytorch/issues/169656
view_as_uint = src_cache.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
if view_as_uint:
    src_cache = src_cache.view(torch.uint8)
View as uint8? Can you explain more about how it helps here?
index_cpu doesn't support the fp8 data type; the view as uint8 here is only for data movement.
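A small demonstration of that point (assumed behavior; whether direct fp8 indexing fails depends on the PyTorch build and device): routing the gather through a uint8 view moves the exact bytes, so viewing the result back yields bit-identical fp8 values.

```python
import torch

cache = torch.randn(6, 4).to(torch.float8_e5m2)   # toy fp8 cache
idx = torch.tensor([0, 3, 5])

# Index through the uint8 view (the transport trick), then reinterpret back.
moved = cache.view(torch.uint8)[idx].view(torch.float8_e5m2)

# The selected rows are bit-identical to the corresponding source rows.
assert torch.equal(moved.view(torch.uint8), cache.view(torch.uint8)[idx])
```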