
Add TurboQuant KV cache compression for MHA #261

Merged
WindChimeRan merged 1 commit into vllm-project:main from UndercoverMathGuy:main
Apr 18, 2026

Conversation

@UndercoverMathGuy (Contributor) commented Apr 12, 2026

What

Adds an opt-in quantized KV cache for MHA paged attention, controlled by an env var plus two CLI flags:

VLLM_METAL_TURBOQUANT=1 # enables TQ
--turboquant-k-quant # key quant type (q8_0 / q5_0 / q4_0 / uint2 / int8 / uint8 / etc.)
--turboquant-v-quant # value quant type (q8_0 / q5_0 / q4_0 / q3_0 / q2_0)

  • Key quantization — block-wise scalar quant (configurable, 2–8 bit) with WHT-based grouping; the per-block scale is stored alongside the packed data.
  • Value quantization — Lloyd-Max quantization from dynamically computed centroids (2–8 bit); both schemes are sketched after this list.
  • Metal dequant kernel — single-pass V lookup + K unpack fused into the paged attention decode path; no separate dequant pass.
  • Block size accounting — get_cache_block_size_bytes reports the correct compressed size so the engine allocates proportionally more blocks.
  • TurboQuant is MHA- and hybrid-only (is_mla=False); MLA models raise NotImplementedError when used with TurboQuant.
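
A minimal Python sketch of the two schemes as described above; the block size, centroid initialization, and iteration count are illustrative assumptions, not the kernel's actual parameters:

```python
import torch

def quantize_key_blockwise(k: torch.Tensor, bits: int = 8, block: int = 32):
    """Block-wise scalar quant: one scale per `block` contiguous values.

    Assumes k.numel() is divisible by `block` (true for head_dim=128).
    """
    qmax = 2 ** (bits - 1) - 1
    blocks = k.float().reshape(-1, block)
    scale = blocks.abs().amax(dim=1, keepdim=True) / qmax  # per-block scale
    scale = scale.clamp(min=1e-8)                          # avoid div-by-zero
    q = (blocks / scale).round().clamp(-qmax - 1, qmax).to(torch.int8)
    return q, scale                                        # codes + scales

def lloyd_max_value_quant(v: torch.Tensor, bits: int = 3, iters: int = 10):
    """Lloyd-Max: 1-D k-means producing 2**bits centroids for the V cache."""
    x = v.float().flatten()
    # Initialize centroids from quantiles so every code starts populated.
    centroids = torch.quantile(x, torch.linspace(0, 1, 2 ** bits))
    for _ in range(iters):
        idx = (x[:, None] - centroids[None, :]).abs().argmin(dim=1)
        for c in range(len(centroids)):       # centroid = mean of its cell
            members = x[idx == c]
            if members.numel() > 0:
                centroids[c] = members.mean()
    idx = (x[:, None] - centroids[None, :]).abs().argmin(dim=1)
    return idx.to(torch.uint8), centroids     # codes + codebook
```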

Why

Unlocks free context length increases with smarter KV cache compression on memory-constrained Apple Silicon (e.g. 8 GB M1 MBA).

| Metric | bf16 | K-q8_0; V-3bit TQ |
| --- | --- | --- |
| Max context length (Llama 3.2 1B, M1 8 GB MBA) | 38,768 tok | 99,280 tok |
| Compression | 1.0x | 2.56x |
| Tok/s | 19.2 | 16.8 (small model: compute-bound, not memory-bound) |

3.7x theoretical compression of KV cache (q4_0 quantization vs. bf16)
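
For intuition on the theoretical figure: assuming q4_0 packs each block of 64 values as 64 four-bit codes plus one fp16 scale (the block size is an assumption; it isn't stated here), the ratio vs. bf16 is (64×16)/(64×4 + 16) ≈ 3.76x, consistent with the 3.7x quoted; a 32-value block would give ≈ 3.56x.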

Quantization quality

Measured on random bf16 tensors (head_dim=128):

| Quantization | Bits | Cos. Sim. |
| --- | --- | --- |
| q8_0 | 8 | 0.999989 |
| q5_0 | 5 | 0.999206 |
| q4_0 | 4 | 0.996761 |
| uint2 | 2 | 0.926900 |
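
A round-trip fidelity check of this kind takes only a few lines; quantize and dequantize below are hypothetical stand-ins for the real TQ encode/decode, which this sketch does not implement:

```python
import torch

def cosine_fidelity(quantize, dequantize, n=4096, head_dim=128):
    """Cosine similarity between random bf16 tensors and their quantized
    round-trip, mirroring the measurement above."""
    x = torch.randn(n, head_dim, dtype=torch.bfloat16)
    x_hat = dequantize(*quantize(x))  # hypothetical encode/decode pair
    return torch.nn.functional.cosine_similarity(
        x.float().flatten(), x_hat.float().flatten(), dim=0
    ).item()
```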

Overhead

Single-token encode + cache write on the Qwen3-0.6B shape (4 KV heads, head_dim=128):

  • TQ encode (Python): ~610 µs
  • TQ cache write: ~360 µs (vs. 340 µs for fp16 — near-zero write overhead)
  • Net TQ overhead: ~640 µs, ≈ 4% of the ~16.7 ms per-token budget at 60 tok/s

Future work

  • Move TurboQuantEncode to a Metal kernel to eliminate the ~610 µs Python encode overhead

@UndercoverMathGuy force-pushed the main branch 2 times, most recently from 73a42af to fe52c01 on April 12, 2026 at 06:30
@WindChimeRan (Collaborator) left a comment

  • Deep coupling with pagedattention.metal; suggest moving the TQ logic into a separate kernel file.
  • This PR introduces env vars; is it possible to use --kv-cache-dtype instead?
  • turboquant.py uses import logging directly instead of from vllm.logger import init_logger like the rest of the codebase (the standard pattern is shown below).
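
For reference, the convention being asked for is the standard vLLM logger pattern:

```python
from vllm.logger import init_logger

logger = init_logger(__name__)  # module-level logger, vLLM-wide convention

logger.info("TurboQuant KV cache enabled")
```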

@UndercoverMathGuy (Contributor, Author)

Thanks for the review - will fix

@UndercoverMathGuy force-pushed the main branch 2 times, most recently from 497bf6c to dcd5cf1 on April 12, 2026 at 12:39
@UndercoverMathGuy (Contributor, Author)

Fixed the concerns @WindChimeRan - moved the TQ logic to turboquant.metal; added --turboquant-k-quant and --turboquant-v-quant args (V supports uint3 only so far); fixed the logger import and lint errors

@ericcurtin (Collaborator)

  • CLI arg name mismatch (--turboquant-k-quant vs. --turboquant-k-dtype)
  • Wrong env var in the serve benchmark (VLLM_METAL_K_QUANT → VLLM_METAL_TURBOQUANT_K_DTYPE)
  • Help text typo (uint3 → q3_0)

@UndercoverMathGuy (Contributor, Author)

Thanks @ericcurtin - fixed all the typos

@WindChimeRan (Collaborator) left a comment

Env vars should be for the platform, not for a quantization method.

Please consider the following approach to remove the TQ env var:

Replace MetalConfig.from_env() with a from_vllm_config(vllm_config) classmethod that reads vllm_config.additional_config:

```python
# In MetalPlatform.check_and_update_config, after existing logic:
add = vllm_config.additional_config or {}
if add.get("turboquant"):
    cfg = get_config()
    cfg.turboquant = True
    cfg.k_quant = add.get("k_quant", "q8_0")
    cfg.v_quant = add.get("v_quant", "q3_0")
    cfg._validate_turboquant()  # reuse the existing validators
```

Invocation:

```bash
vllm serve Llama-3.2-1B \
    --additional-config '{"turboquant": true, "k_quant": "q4_0", "v_quant": "q3_0"}'
```

@UndercoverMathGuy (Contributor, Author)

@WindChimeRan new arg method added - uses the additional-config JSON to activate TurboQuant and set its params

@WindChimeRan (Collaborator)

@UndercoverMathGuy test failed. Could you please investigate why?

@UndercoverMathGuy (Contributor, Author)

Test fixes - added getattr to guard against a missing argument
Merge - needs a refactor for the new WorkerCache Planner (WIP)

@UndercoverMathGuy (Contributor, Author)

Modified worker.py for the new cache API and found a bug in the GPU KV cache allocation, which was still using fp16 for the size calculation (fixed)

@UndercoverMathGuy (Contributor, Author)

@WindChimeRan tests should pass now; they worked locally for me

@ricky-chaoju (Contributor)

@UndercoverMathGuy you need to sign your commits to pass the DCO check

@UndercoverMathGuy (Contributor, Author)

My DCO errors seem to be clear now - @ricky-chaoju

@ricky-chaoju (Contributor)

> My DCO errors seem to be clear now - @ricky-chaoju

The DCO check is still failing; you'll need to fix it on your side.

Details:
https://github.com/vllm-project/vllm-metal/pull/261/checks?check_run_id=71772953461

Please refer to the contributing guide:
https://docs.vllm.ai/projects/vllm-metal/en/latest/CONTRIBUTING/

@UndercoverMathGuy (Contributor, Author)

@ricky-chaoju - I've signed every single one of my own commits, but the DCO check is erroring because of an unexpected name mismatch:

Commit sha: [e813597](https://github.com/vllm-project/vllm-metal/pull/261/commits/e813597ae13fa25de883de7f478b9c4f51650930), Author: Lik Xun Yuan (Lx), Committer: GitHub; Expected "Lik Xun Yuan (Lx) [lxyuan0420@gmail.com](mailto:lxyuan0420@gmail.com)", but got "Yuan Lik Xun [lxyuan0420@gmail.com](mailto:lxyuan0420@gmail.com)".
Commit sha: [5e82935](https://github.com/vllm-project/vllm-metal/pull/261/commits/5e8293517148143a89de6240afbe7d8ab21cd1ad), Author: Lik Xun Yuan (Lx), Committer: GitHub; Expected "Lik Xun Yuan (Lx) [lxyuan0420@gmail.com](mailto:lxyuan0420@gmail.com)", but got "Yuan Lik Xun [lxyuan0420@gmail.com](mailto:lxyuan0420@gmail.com)".
Commit sha: [d911293](https://github.com/vllm-project/vllm-metal/pull/261/commits/d91129356ac60b1033d9171aa7b26fc1b1041bed), Author: Ranran, Committer: GitHub; Expected "Ranran [hzz5361@psu.edu](mailto:hzz5361@psu.edu)", but got "ran [hzz5361@psu.edu](mailto:hzz5361@psu.edu)".
Commit sha: [3592a8c](https://github.com/vllm-project/vllm-metal/pull/261/commits/3592a8c99730a5e66132df08fafe90cddbfc2a58), Author: Ranran, Committer: GitHub; Expected "Ranran [hzz5361@psu.edu](mailto:hzz5361@psu.edu)", but got "ran [hzz5361@psu.edu](mailto:hzz5361@psu.edu)".

@ricky-chaoju (Contributor)

@UndercoverMathGuy The DCO failures are from merge commits rather than your own commits. Could you try rebasing onto main instead of merging?

```bash
git fetch upstream
git rebase upstream/main
git push --force-with-lease origin main
```

That should leave only your own (correctly signed) commits in the PR and clear DCO.

@UndercoverMathGuy (Contributor, Author)

Yep - @ricky-chaoju - thanks so much for the help - DCO green now

@ericcurtin (Collaborator)

Critical:

  • The YOCO shared_kv is not None guard was removed from the non-TurboQuant path (silent cache corruption for Gemma-3/YOCO models)
  • Hardcoded RNG-key/Metal-sign table coupling
  • Duplicated TQ byte calculation
  • head_size_v integer division may be wrong

@UndercoverMathGuy (Contributor, Author)

@ericcurtin - suggestions fixed, thanks: YOCO guard added; the Metal sign table and RNG stay hardcoded but now have a validation test (haven't made them dynamic); the TQ byte calculation is now a function; added a FullAttentionSpec subclass for TurboQuant with a custom byte calculation for KV cache allocation

@ricky-chaoju (Contributor)

Non-blocking: the YOCO shared_kv guard sits behind the if kv_cache.turboquant branch, so it's unreachable when TQ is active. Moving the shared_kv is not None check before the TQ branch preserves the guard for all paths.
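
A minimal sketch of the suggested ordering (every name here is illustrative, not the actual code):

```python
def write_turboquant(kv_cache, key, value):
    ...  # stand-in for the TQ encode + packed cache write

def write_fp16(kv_cache, key, value):
    ...  # stand-in for the plain fp16 cache write

def write_kv(kv_cache, key, value, shared_kv=None):
    # YOCO guard first, so it protects the TQ path and the fp16 path alike.
    if shared_kv is not None:
        key, value = shared_kv
    if kv_cache.turboquant:
        write_turboquant(kv_cache, key, value)
    else:
        write_fp16(kv_cache, key, value)
```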

@UndercoverMathGuy (Contributor, Author)

@ricky-chaoju - fixed - thanks

Comment thread on vllm_metal/metal_kernel_backend/attention_sdpa.py
@WindChimeRan (Collaborator) left a comment

vllm-metal/mkdocs.yaml, lines 68–69 at e115dda:

```yaml
- Features:
    - Speech-to-Text: stt.md
```

Please add docs under the Features tab.
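
The docs were added later in the thread (turboquant.md, per the Apr 18 comment below); the nav entry presumably looks something like this, with the label being an assumption:

```yaml
- Features:
    - Speech-to-Text: stt.md
    - TurboQuant KV Cache: turboquant.md  # label assumed; file name from the thread
```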

@WindChimeRan (Collaborator)

Non-blocking: this PR puts the quant decision inside the attention forward pass rather than at the backend-selection boundary.

We'll need a refactor of attention backend dispatching: TQ should be a separate backend in a separate file. But that refactor would introduce a lot of changes, so we'll do it in a future PR.

@WindChimeRan (Collaborator) left a comment

Non-blocking: upstream vLLM described an issue with TQ: vllm-project/vllm#38280

I'm mostly interested in:

    1. Hybrid models (Qwen3.5) not supported --> do we have a fail-loud path for this?
    2. Minimum 4-bit quantization --> "3-bit and 2-bit quantization produce garbage output." Is this true in our case?

These are non-blocking for this PR. But if we indeed have the same issues, we need to document them somewhere (1 --> a new issue, 2 --> documentation).

@UndercoverMathGuy (Contributor, Author) commented Apr 18, 2026

@WindChimeRan - fixed the head size bug in attention; added documentation for TQ (turboquant.md, added to mkdocs.yaml). The minimum K quant of 4 bits holds here too - any quantization below 4 bits produces nonsense from the model (either infinite loops or unrelated ramblings) - added this to the documentation.

This implementation of TurboQuant does support hybrid models (Qwen 3.5, tested and verified): the current hybrid-model attention framework is just a wrapper over the otherwise-independent GDN and standard SDPA paths (a patch function checks the layer list and routes each layer to the correct attention type), so we can activate TQ on the SDPA side without any allocation or other incompatibilities with GDN, as the two are completely separate. A sketch of that routing shape follows.
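
All names below are illustrative assumptions; this only sketches the routing described above, not the actual patch function:

```python
def gdn_forward(layer_idx, *args, **kwargs):
    ...  # linear-attention (GDN) path: TurboQuant never touches this

def sdpa_forward(layer_idx, *args, **kwargs):
    ...  # standard SDPA path: the TQ-compressed KV cache lives here

def patched_attention(layer_idx, gdn_layers, *args, **kwargs):
    # Route each layer by type: the split that keeps TQ and GDN apart.
    if layer_idx in gdn_layers:
        return gdn_forward(layer_idx, *args, **kwargs)
    return sdpa_forward(layer_idx, *args, **kwargs)
```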

@WindChimeRan WindChimeRan merged commit ec35365 into vllm-project:main Apr 18, 2026
5 checks passed
@UndercoverMathGuy (Contributor, Author)

Thanks @WindChimeRan - would you please close issue #217?
