[Performance] Eliminate unnecessary H2D copies in FlashInfer decode #21854
Conversation
Code Review

This pull request optimizes the FlashInfer backend by eliminating unnecessary host-to-device memory copies, leading to a performance improvement. The approach of pre-creating tensors on the device and passing them down is sound.

The main concern is the introduction of the `fast_decode_plan` function, which duplicates code from the `flashinfer` library, posing a future maintenance risk. I've provided a detailed comment with suggestions to improve the maintainability of this new function.
```python
def fast_decode_plan(
        self,
```
This new function `fast_decode_plan` appears to be a modified copy of `flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper.plan`. While this is effective for the performance gain, it introduces a maintenance risk by duplicating code from an external library. If the upstream `flashinfer` library updates its `plan` method, this code may become outdated.

The ideal long-term solution would be to contribute this change upstream to `flashinfer`.

For the current implementation, I have a few suggestions to improve clarity and maintainability:

- **Function Signature:** The first argument is named `self`, which is idiomatic for instance methods but confusing for a standalone function. Renaming it to `wrapper` would clarify that it expects an instance of `BatchDecodeWithPagedKVCacheWrapper`.
- **Add Explanatory Comment:** A comment at the beginning of the function explaining why it's a copy of the upstream code and what changes were made would be very helpful for future maintainers. For example:

  ```python
  # This function is a modified version of flashinfer's
  # BatchDecodeWithPagedKVCacheWrapper.plan method. It has been refactored
  # to accept pre-allocated device tensors (indptr, last_page_len) to
  # avoid unnecessary H2D copies within the planning phase.
  # TODO(author): Consider upstreaming this change to flashinfer.
  ```

- **Visibility:** Since this function is only used within this module, consider renaming it to `_fast_decode_plan` to indicate it's a private helper function.
```diff
-def fast_decode_plan(
-        self,
+def _fast_decode_plan(
+        wrapper,
```
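Taken together, the suggestions might look roughly like this (a sketch only; the parameter list beyond `wrapper` is abbreviated and hypothetical, and the real body would mirror flashinfer's upstream `plan`):

```python
def _fast_decode_plan(
        wrapper,         # a BatchDecodeWithPagedKVCacheWrapper instance
        indptr,          # pre-allocated device tensor
        last_page_len,   # pre-allocated device tensor
        *args,
        **kwargs):
    # Modified copy of flashinfer's BatchDecodeWithPagedKVCacheWrapper.plan,
    # refactored to accept pre-allocated device tensors (indptr,
    # last_page_len) so the planning phase avoids extra H2D copies.
    # TODO: consider upstreaming this change to flashinfer.
    ...
```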
```python
paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(
    dim=0, dtype=torch.int32)
```

```python
paged_kv_indptr = torch.zeros(len(block_table_bounds) + 1,
                              dtype=torch.int32,
                              device=self.device)
paged_kv_indptr[1:] = block_table_bounds.cumsum(dim=0,
                                                dtype=torch.int32)
```

```python
paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
paged_kv_last_page_len_cpu = torch.where(
    paged_kv_last_page_len_cpu == 0, page_size,
    paged_kv_last_page_len_cpu)
```

```python
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                     page_size, paged_kv_last_page_len)
```
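For reference, a self-contained snippet reproducing the device-side computation above with illustrative dummy inputs (the tensor values and sizes here are hypothetical):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
page_size = 16
# Per-request page counts and sequence lengths (illustrative values).
block_table_bounds = torch.tensor([3, 1, 5], dtype=torch.int32, device=device)
seq_lens = torch.tensor([40, 16, 70], dtype=torch.int32, device=device)

# Exclusive prefix sum over page counts -> indptr of shape [num_reqs + 1].
paged_kv_indptr = torch.zeros(len(block_table_bounds) + 1,
                              dtype=torch.int32,
                              device=device)
paged_kv_indptr[1:] = block_table_bounds.cumsum(dim=0, dtype=torch.int32)

# Number of valid tokens in each request's last page (page_size if full).
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                     page_size, paged_kv_last_page_len)

print(paged_kv_indptr)         # tensor([0, 3, 4, 9], ...)
print(paged_kv_last_page_len)  # tensor([ 8, 16,  6], ...)
```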
@LucasWilkinson am I correct that the device versions of these tensors must be computed as I've done here, or are they available elsewhere? I'm seeing that the cost of this compute negates the benefit of eliminating the copies.
Ya, we would need to compute these on device; however, that should be non-blocking on the CPU side (the bottleneck). Are you seeing a lot of CPU overhead in the trace?
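A minimal way to see that non-blocking launch behavior, using the same ops (illustrative sizes only; requires a CUDA device):

```python
import time
import torch

page_size = 16
seq_lens = torch.randint(1, 4096, (100_000,), dtype=torch.int32, device="cuda")
torch.cuda.synchronize()

t0 = time.perf_counter()
last_page_len = seq_lens % page_size
last_page_len = torch.where(last_page_len == 0, page_size, last_page_len)
t_launch = time.perf_counter() - t0   # CPU time spent just launching kernels

torch.cuda.synchronize()              # wait for the GPU work to finish
t_total = time.perf_counter() - t0
print(f"launch: {t_launch * 1e6:.0f} us, end-to-end: {t_total * 1e6:.0f} us")
```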
I see the `where` is taking a long time:

We might be able to play some tricks here, like doing (I think this is correct?!):

```python
paged_kv_last_page_len = (seq_lens % -page_size) + page_size
```

instead of

```python
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                     page_size, paged_kv_last_page_len)
```
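A quick sanity check with hypothetical values, relying on PyTorch's Python-style modulo semantics, that the negative-modulo trick matches the `torch.where` formulation:

```python
import torch

page_size = 16
seq_lens = torch.tensor([1, 15, 16, 17, 32, 33])

trick = (seq_lens % -page_size) + page_size

reference = seq_lens % page_size
reference = torch.where(reference == 0, page_size, reference)

assert torch.equal(trick, reference)
print(trick)  # tensor([ 1, 15, 16,  1, 16,  1])
```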
Done, this and your other suggestion about `cumsum()` helped!
Do you have any new numbers?
Purpose
In the FlashInfer backend, `BatchDecodeWithPagedKVCacheWrapper::plan()` copies tensors to device, which is unnecessary when device copies already exist. This PR refactors to eliminate these copies.

Test Plan
Correctness: `pytest tests/v1/attention/test_attention_backends.py`

Performance:
Test Result

Correctness passes. `FlashInferMetadataBuilder::_plan()` duration reduced from 150us to 100us.

(Optional) Documentation Update