[Performance] Eliminate unnecessary H2D copies in FlashInfer decode #21854

Open · wants to merge 5 commits into main from feature/flashinfer_fast_decode_plan

Conversation

@MatthewBonanni (Contributor) commented on Jul 29, 2025:

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

In the FlashInfer backend, BatchDecodeWithPagedKVCacheWrapper::plan() copies tensors to the device, which is unnecessary when device-resident copies already exist. This PR refactors the decode planning path to eliminate these copies.
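
As a concrete illustration of the change, here is a minimal sketch (the helper name and argument list are illustrative, not the actual vLLM code) of building the planning inputs directly on the GPU from tensors that already live there, so that plan() no longer needs to stage them on the host and copy them over:

    import torch

    def build_decode_plan_inputs(block_table_bounds: torch.Tensor,
                                 seq_lens: torch.Tensor,
                                 page_size: int):
        # Assumes block_table_bounds and seq_lens are already device tensors.
        device = block_table_bounds.device

        # Prefix sum of per-request block counts -> paged KV indptr,
        # constructed directly on the device (no CPU staging, no H2D copy).
        paged_kv_indptr = torch.zeros(block_table_bounds.numel() + 1,
                                      dtype=torch.int32, device=device)
        paged_kv_indptr[1:] = block_table_bounds.cumsum(dim=0,
                                                        dtype=torch.int32)

        # Length of the last (possibly partial) page of each sequence,
        # in the range [1, page_size].
        paged_kv_last_page_len = (seq_lens % -page_size) + page_size

        return paged_kv_indptr, paged_kv_last_page_len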

Test Plan

Correctness:

pytest tests/v1/attention/test_attention_backends.py

Performance:

vllm bench throughput \
        --model "NousResearch/Hermes-3-Llama-3.1-8B" \
        --dataset-name "random" \
        --input-len 128 \
        --output-len 512 \
        --num-prompts 100

Test Result

Correctness tests pass. FlashInferMetadataBuilder::_plan() duration is reduced from 150 µs to 100 µs.

(Optional) Documentation Update

@mergify (bot) added the v1 label on Jul 29, 2025.
@gemini-code-assist (bot) left a comment:

Code Review

This pull request optimizes the FlashInfer backend by eliminating unnecessary host-to-device memory copies, leading to a performance improvement. The approach of pre-creating tensors on the device and passing them down is sound.

The main concern is the introduction of the fast_decode_plan function, which duplicates code from the flashinfer library, posing a future maintenance risk. I've provided a detailed comment with suggestions to improve the maintainability of this new function.

Comment on lines 719 to 720:

    def fast_decode_plan(
        self,

Severity: high

This new function fast_decode_plan appears to be a modified copy of flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper.plan. While this is effective for the performance gain, it introduces a maintenance risk by duplicating code from an external library. If the upstream flashinfer library updates its plan method, this code may become outdated.

The ideal long-term solution would be to contribute this change upstream to flashinfer.

For the current implementation, I have a few suggestions to improve clarity and maintainability:

  1. Function Signature: The first argument is named self, which is idiomatic for instance methods but confusing for a standalone function. Renaming it to wrapper would clarify that it expects an instance of BatchDecodeWithPagedKVCacheWrapper.
  2. Add Explanatory Comment: A comment at the beginning of the function explaining why it's a copy of the upstream code and what changes were made would be very helpful for future maintainers. For example:
    # This function is a modified version of flashinfer's
    # BatchDecodeWithPagedKVCacheWrapper.plan method. It's been refactored
    # to accept pre-allocated device tensors (indptr, last_page_len) to
    # avoid unnecessary H2D copies within the planning phase.
    # TODO(author): Consider upstreaming this change to flashinfer.
  3. Visibility: Since this function is only used within this module, consider renaming it to _fast_decode_plan to indicate it's a private helper function.
Suggested change:

    -def fast_decode_plan(
    -    self,
    +def _fast_decode_plan(
    +    wrapper,
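
A minimal sketch of what the suggested private helper could look like; everything beyond the rename and the explanatory docstring is illustrative, and *plan_args/**plan_kwargs stand in for flashinfer's real plan() parameters, which are not reproduced here:

    def _fast_decode_plan(wrapper, indptr, last_page_len,
                          *plan_args, **plan_kwargs):
        """Modified copy of flashinfer's
        BatchDecodeWithPagedKVCacheWrapper.plan, refactored to accept
        pre-allocated device tensors (indptr, last_page_len) so that the
        planning phase performs no H2D copies.

        TODO: consider upstreaming this change to flashinfer.
        """
        # Body omitted: it mirrors the upstream plan() implementation,
        # minus the host-to-device copies of indptr and last_page_len.
        ...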


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@MatthewBonanni force-pushed the feature/flashinfer_fast_decode_plan branch from 5317449 to 9fcd3d2 on Jul 29, 2025 at 17:46.
@MatthewBonanni changed the title from "[Performance] Eliminate unnecessary H2D copies in FlashInfer backend" to "[Performance] Eliminate unnecessary H2D copies in FlashInfer decode" on Jul 29, 2025.
Comment on lines 449 to 466:

    paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(
        dim=0, dtype=torch.int32)

    paged_kv_indptr = torch.zeros(len(block_table_bounds) + 1,
                                  dtype=torch.int32,
                                  device=self.device)
    paged_kv_indptr[1:] = block_table_bounds.cumsum(dim=0,
                                                    dtype=torch.int32)

    paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
    paged_kv_last_page_len_cpu = torch.where(
        paged_kv_last_page_len_cpu == 0, page_size,
        paged_kv_last_page_len_cpu)

    paged_kv_last_page_len = seq_lens % page_size
    paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                         page_size, paged_kv_last_page_len)

@MatthewBonanni (author) commented:

@LucasWilkinson am I correct that the device versions of these tensors must be computed as I've done here? Or are they available elsewhere? I'm showing that the cost of this compute negates the benefit of eliminating the copies

Collaborator commented:

Yeah, we would need to compute these on device; however, that should be non-blocking on the CPU side (the bottleneck). Are you seeing a lot of CPU overhead in the trace?

Collaborator commented:

I see the torch.where is taking a long time:

[profiler trace image]

We might be able to play some tricks here, like doing (I think this is correct?!):

    paged_kv_last_page_len = (seq_lens % -page_size) + page_size

instead of

    paged_kv_last_page_len = seq_lens % page_size
    paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                         page_size, paged_kv_last_page_len)
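
For reference, a quick self-contained check (with toy values assumed for seq_lens and page_size) that the one-liner matches the torch.where version; it relies on PyTorch's % following Python's floored-modulo semantics, where the result takes the sign of the divisor:

    import torch

    page_size = 16
    seq_lens = torch.tensor([1, 15, 16, 17, 32, 33], dtype=torch.int32)

    # Original two-step version: map exact multiples of page_size back to
    # a full page.
    via_where = seq_lens % page_size
    via_where = torch.where(via_where == 0, page_size, via_where)

    # One-step version: seq_lens % -page_size lies in (-page_size, 0],
    # so adding page_size yields values in (0, page_size].
    via_trick = (seq_lens % -page_size) + page_size

    assert torch.equal(via_where, via_trick)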

@MatthewBonanni (author) commented:

Done, this and your other suggestion about cumsum() helped!

Collaborator commented:

Do you have any new numbers?
