
[Paged KV] Enable prefix caching on the unified paged path#283

Merged
WindChimeRan merged 4 commits into vllm-project:main from ricky-chaoju:feat/enable-prefix-caching
Apr 20, 2026

Conversation

@ricky-chaoju (Contributor) commented Apr 19, 2026

Summary
Removes the platform-layer force-disable that was blocking enable_prefix_caching on the paged path. The unified prefill code already handles num_computed_tokens > 0 (#195, #207, #208, #211); the only remaining gap was that platform.py:278-285 overrode it.

Adds an end-to-end correctness test that fires identical prompts twice through vllm.LLM(enable_prefix_caching=True). The second pass walks the start_pos > 0 path, and the generated tokens still match the cache-off golden output.
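The second-pass behavior can be sketched as follows. This is an illustrative model of block-aligned prefix matching, not vllm-metal's actual code: `BLOCK_SIZE`, `cached_prefix_len`, and the dict-based cache are all assumptions made for the example.

```python
# Illustrative sketch: how a paged KV cache recognizes a previously
# computed prefix, so a repeated prompt starts with start_pos > 0.
BLOCK_SIZE = 16  # tokens per KV block (hypothetical value)

def cached_prefix_len(prompt: list[int], cached_blocks: dict[tuple, int]) -> int:
    """Return num_computed_tokens: the longest block-aligned prefix of
    `prompt` whose KV blocks are already present in the cache."""
    computed = 0
    for start in range(0, len(prompt) - BLOCK_SIZE + 1, BLOCK_SIZE):
        block_tokens = tuple(prompt[start:start + BLOCK_SIZE])
        if block_tokens not in cached_blocks:
            break
        computed += BLOCK_SIZE
    return computed

# First pass: nothing cached, so prefill starts at position 0.
cache: dict[tuple, int] = {}
prompt = list(range(40))  # 40-token prompt -> 2 full blocks + 8-token tail
assert cached_prefix_len(prompt, cache) == 0

# Populate the cache with the prompt's full blocks (as prefill would).
for start in range(0, len(prompt) - BLOCK_SIZE + 1, BLOCK_SIZE):
    cache[tuple(prompt[start:start + BLOCK_SIZE])] = start // BLOCK_SIZE

# Second identical pass: both full blocks hit, so prefill resumes with
# num_computed_tokens = 32, i.e. start_pos > 0. The 8-token tail is recomputed.
assert cached_prefix_len(prompt, cache) == 32
```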

Both test classes share a single LLM fixture: Metal memory held by a released LLM is not freed by Python gc, so a second module-scope LLM would hit kv_budget=0.

Hybrid models

Upstream ModelConfig.is_prefix_caching_supported already returns False for hybrid/Mamba models, so the default_prefix_caching resolution in vllm/engine/arg_utils.py keeps cache off unless the user explicitly forces it. No vllm-metal-side guard needed.
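The resolution described above can be sketched in a few lines. The function name and signature here are simplified stand-ins, not the actual code in vllm/engine/arg_utils.py; the point is only the precedence order (explicit user flag wins, otherwise model support decides).

```python
# Minimal sketch of the default_prefix_caching resolution described above.
from typing import Optional

def resolve_prefix_caching(user_value: Optional[bool],
                           model_supports_prefix_caching: bool) -> bool:
    """If the user set --enable-prefix-caching explicitly, honor it;
    otherwise default to whatever the model supports."""
    if user_value is not None:
        return user_value  # explicit user choice wins
    return model_supports_prefix_caching

# Hybrid/Mamba model: is_prefix_caching_supported -> False, so the
# default keeps the cache off...
assert resolve_prefix_caching(None, False) is False
# ...unless the user explicitly forces it on.
assert resolve_prefix_caching(True, False) is True
# Regular transformer: defaults to on.
assert resolve_prefix_caching(None, True) is True
```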

Benchmark

(benchmark chart: benchmark_182)

Closes #182.

Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
@WindChimeRan (Collaborator) left a comment


Maybe we should put the Paged Prefix Cache tests into separate files.

Comment thread: vllm_metal/platform.py

Collaborator:

This will turn on prefix caching as default for user. Could you please run a quick benchmark comparing with and without prefix caching? We can check TTFT, throughput, and cache hit rate, to see if it actually works. For the dataset, concatenating a shared system prompt should be fine.

@ricky-chaoju (Contributor, Author) commented Apr 19, 2026

Benchmark with vllm bench serve --dataset-name prefix_repetition on Qwen3-0.6B (paged attention). Workload: one shared 512-token prefix, 50 prompts each with a 128-token suffix and a 100-token output; 10 repeats per side, median ± pstdev reported.

| Metric | Cache off (med ± sd) | Cache on (med ± sd) | Δ |
| --- | --- | --- | --- |
| Throughput (tok/s) | 210.59 ± 8.93 | 288.88 ± 12.23 | 1.37× |
| TTFT mean (ms) | 8667.88 ± 588.32 | 6071.39 ± 554.66 | 1.43× |
| TTFT P50 (ms) | 8622.69 ± 645.68 | 5894.11 ± 451.23 | 1.46× |
| TTFT P99 (ms) | 16308.74 ± 1049.18 | 11523.16 ± 943.63 | 1.42× |
| TPOT P50 (ms) | 30.74 ± 1.63 | 22.88 ± 0.77 | 1.34× |
| TPOT P99 (ms) | 119.51 ± 51.95 | 25.62 ± 8.99 | 4.67× |
| E2EL P50 (ms) | 11574.40 ± 664.36 | 8141.27 ± 652.99 | 1.42× |

Cache hit rate (from /metrics):

  • Cold (run 1): 78.4% (25088 / 32000)
  • Warm (runs 2-10, each): 97.5% (31200 / 32000)
  • Cumulative across 10 cache-on runs: 95.6% (305888 / 320000)

baseline-* runs report 0 queries / 0 hits (cache disabled).
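The hit-rate percentages above can be checked directly from the quoted query/hit counts; this is just arithmetic on the numbers already in this comment, not output from /metrics itself.

```python
# Sanity-check the cache-hit-rate figures quoted above.
queries_per_run = 32000
cold_hits = 25088   # run 1
warm_hits = 31200   # each of runs 2-10 (9 runs)

cold_rate = cold_hits / queries_per_run
warm_rate = warm_hits / queries_per_run
total_hits = cold_hits + 9 * warm_hits
total_queries = 10 * queries_per_run
cumulative_rate = total_hits / total_queries

assert round(cold_rate * 100, 1) == 78.4
assert round(warm_rate * 100, 1) == 97.5
assert total_hits == 305888 and total_queries == 320000
assert round(cumulative_rate * 100, 1) == 95.6
```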

(On TPOT P99: cache-on is stable across 10 runs (std ±9 ms, range 23-50 ms), while the baseline is noisy (std ±52 ms, range 36-197 ms). The single-sweep ratio floats between ~3.5× and ~5.5× because of baseline noise, not cache-side flakiness; the distribution-based 4.67× is the robust number.)

Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
Collaborator:

No cache-hit assert. Suggest asserting on scheduler cache-hit counters or on runner._request_states block reuse; even a single assert computed_tokens > 0 on at least one request would do.
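The suggested assertion could look roughly like this. `RequestState`, `num_computed_tokens`, and the list of states are stand-ins for whatever the runner actually exposes via `_request_states`; all field names here are assumptions, not vllm-metal's real API.

```python
# Hedged sketch: after the second (cache-warm) generate() call, at least
# one request should have resumed past position 0.
from dataclasses import dataclass

@dataclass
class RequestState:
    request_id: str
    num_computed_tokens: int  # tokens served from the prefix cache

def assert_prefix_cache_hit(request_states: list[RequestState]) -> None:
    """Fail unless at least one request reused cached prefix blocks."""
    assert any(s.num_computed_tokens > 0 for s in request_states), \
        "expected at least one request with computed_tokens > 0"

# Example: the second pass over an identical prompt reports reused tokens.
states = [RequestState("req-0", 0), RequestState("req-1", 32)]
assert_prefix_cache_hit(states)
```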

@ricky-chaoju (Contributor, Author):

done

@WindChimeRan (Collaborator):

> Upstream ModelConfig.is_prefix_caching_supported already returns False for hybrid/Mamba models, so the default_prefix_caching resolution in vllm/engine/arg_utils.py keeps cache off unless the user explicitly forces it. No vllm-metal-side guard needed.

Please add a follow-up patch for the docs: #284

Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
@ricky-chaoju (Contributor, Author):

> Upstream ModelConfig.is_prefix_caching_supported already returns False for hybrid/Mamba models, so the default_prefix_caching resolution in vllm/engine/arg_utils.py keeps cache off unless the user explicitly forces it. No vllm-metal-side guard needed.
>
> Please add a followup patch for the doc #284

Will follow up once #284 lands.

@WindChimeRan WindChimeRan merged commit 36f5ba8 into vllm-project:main Apr 20, 2026
5 checks passed


Development

Successfully merging this pull request may close these issues.

[RoadMap] [Paged KV] Prefix Caching on Paged KV Cache Path

2 participants