
feat: add prompt prefix caching to SimpleEngine #90

Open
panbanda wants to merge 1 commit into waybarrios:main from panbanda:feat/prompt-prefix-caching

Conversation


@panbanda panbanda commented Feb 15, 2026

Summary

  • Adds KV prefix caching to SimpleEngine.stream_generate() using mlx-lm's LRUPromptCache (trie-based prefix matching with LRU eviction)
  • In agentic loops (e.g. Claude Code), successive requests share 90-99% of their prompt prefix, making prefill the bottleneck. This brings prompt eval from ~27s to ~1.8s on a 28k-token prompt (15-27x speedup on M4 Max)
  • Strips x-anthropic- tracking headers from system blocks in the Anthropic adapter, since they contain per-request hashes that break prefix matching
  • Configurable via prompt_cache_size parameter (default 10, 0 to disable)

Implementation details

vllm_mlx/models/llm.py: Accept an optional prompt_cache parameter in stream_generate() and pass it through to mlx_lm.stream_generate(). Also accept list[int] (token IDs) as the prompt type, not just str.
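The passthrough described above can be sketched as follows. This is a toy illustration, not the vllm_mlx API: the function name and parameters mirror the description, but the tokenizer and generation backend are stand-ins.

```python
from typing import Iterator, Optional, Union

def stream_generate(
    prompt: Union[str, list[int]],          # str, or pre-tokenized IDs
    prompt_cache: Optional[object] = None,  # opaque KV-cache handle (illustrative)
) -> Iterator[int]:
    """Toy sketch of the prompt/prompt_cache passthrough; not the real implementation."""
    if isinstance(prompt, str):
        # Stand-in for real tokenization.
        tokens = [ord(c) for c in prompt]
    else:
        tokens = prompt
    # A real implementation would forward `tokens` and `prompt_cache` to
    # mlx_lm.stream_generate(); here we simply echo the tokens back.
    yield from tokens
```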

vllm_mlx/engine/simple.py:

  • Tokenize prompt and look up nearest prefix in trie cache via LRUPromptCache.fetch_nearest_cache()
  • Deep copy cache before generation (mlx-lm's generate_step() mutates the cache in place), with try/except fallback if deepcopy fails on MLX arrays
  • Pass only the remaining (non-cached) tokens to the model along with the cached KV state
  • After generation, store prompt-only KV state (excludes generated tokens, since the next request's representation of generated content may differ due to tool call formatting, system reminders, etc.)
  • Cache is cleared on stop() to prevent stale KV state if the engine is recycled
  • prompt_cache_size is configurable (default 10, 0 to disable caching entirely)
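The lookup/copy/store flow above can be sketched with a toy cache. This is not mlx-lm's LRUPromptCache (which is trie-based and holds real KV state); it only illustrates longest-prefix lookup, copy-before-mutation, LRU eviction, and the size=0 disable path, with KV state modeled as a plain list.

```python
from collections import OrderedDict

class ToyPrefixCache:
    """Toy LRU prefix cache illustrating the SimpleEngine flow; illustrative only."""

    def __init__(self, max_size: int = 10):
        self.max_size = max_size
        self._entries: OrderedDict[tuple, list] = OrderedDict()  # token key -> fake KV state

    def fetch_nearest(self, tokens: list[int]):
        """Return (copied cached state, remaining tokens) for the longest cached prefix."""
        best = ()
        for key in self._entries:
            n = len(key)
            if n > len(best) and n <= len(tokens) and list(key) == tokens[:n]:
                best = key
        if not best:
            return [], tokens                      # cache miss: prefill everything
        self._entries.move_to_end(best)            # LRU touch
        # Copy before handing out, since generation mutates the cache in place.
        return list(self._entries[best]), tokens[len(best):]

    def store(self, tokens: list[int], state: list):
        """Store prompt-only KV state under the full prompt key, evicting LRU entries."""
        if self.max_size == 0:
            return                                 # caching disabled
        self._entries[tuple(tokens)] = state
        self._entries.move_to_end(tuple(tokens))
        while len(self._entries) > self.max_size:
            self._entries.popitem(last=False)
```

A second request that extends the first prompt then prefills only the suffix tokens, which is where the large speedups come from.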

vllm_mlx/api/anthropic_adapter.py: Skip system content blocks starting with x-anthropic- (e.g. x-anthropic-billing-header: cc_version=...; cch=HASH). The cch= hash changes every request, causing prefix divergence at token ~33 and defeating the entire cache.
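The stripping step can be sketched as below. The real adapter operates on structured Anthropic system content blocks; this toy version (function name is illustrative) just filters plain strings by the prefix the PR describes.

```python
def strip_tracking_blocks(system_blocks: list[str]) -> list[str]:
    """Drop per-request x-anthropic- tracking blocks whose hashes defeat prefix caching.

    Toy sketch; illustrative only.
    """
    return [block for block in system_blocks if not block.startswith("x-anthropic-")]
```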

tests/test_simple_engine_prefix_cache.py: 7 new tests covering cache miss on first request, disabled cache (size=0), stop() cleanup, configurable size, billing header stripping, and edge cases.

Test results (Claude Code agentic loop)

Round      Tokens reused        Prompt eval time
1 (cold)   0/28457 (0%)         ~27s
2          28457/29485 (96%)    ~1.8s
3          29485/29650 (99%)    ~0.7s
4          29650/29738 (99%)    ~1.4s

Test plan

  • 7 new unit tests (all passing)
  • Existing 50 tests unaffected (all passing)
  • Verify cold start (first request) works normally with 0% cache hit
  • Verify second request to same conversation shows high cache hit rate
  • Verify growing conversations maintain high reuse (96%+)
  • Verify prompt_cache_size=0 disables caching cleanly
  • Verify stop() clears stale cache state

Generated with Claude Code

@panbanda panbanda force-pushed the feat/prompt-prefix-caching branch 2 times, most recently from 07c4c7d to aeb9614 on February 15, 2026 at 22:57
@waybarrios waybarrios self-requested a review February 16, 2026 01:38
@waybarrios waybarrios self-assigned this Feb 16, 2026
@waybarrios waybarrios added the enhancement label Feb 16, 2026
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@janhilgard janhilgard force-pushed the feat/prompt-prefix-caching branch from aeb9614 to e92178f on February 16, 2026 at 08:33
@waybarrios (Owner)

Really nice work on this. The 15-27x speedup numbers are impressive, and using mlx-lm's LRUPromptCache with trie-based prefix matching is a solid approach for agentic loops where prompts share most of their prefix. The test coverage is thorough too, especially the edge cases around cache disabled and stop clearing.

I noticed a few things that might need attention:

  1. The cache snapshot is taken before generation runs (cache_snapshot = copy.deepcopy(cache) at line ~227), then stored under the full token key after generation finishes. At that point the snapshot contains only the pre-prefill state (empty on miss, or partial prefix on hit), not the full prompt's KV state. On the next request with the same tokens, fetch_nearest_cache returns this incomplete snapshot and remaining ends up empty, so the model generates from wrong KV state. Commit b191aec fixed this same pattern for BatchedEngine by capturing cache state after prompt processing, not before.

  2. Related to the above: copy.deepcopy(cache) on MLX KVCache objects doesn't copy the underlying Metal buffers. The Python-level copy succeeds but both objects share the same Metal memory. When generation mutates the cache in-place, the snapshot is corrupted. Commit 3f8b006 explicitly removed deepcopy from PrefixCacheManager for this reason, noting "MLX arrays are immutable" (copy-on-write semantics make deepcopy unnecessary but also misleading).

  3. LRUPromptCache.fetch_nearest_cache(model, tokens) expects a model key for cache namespace separation. The PR passes self._model_name (a string). The mlx-lm server uses a tuple (model_path, adapter_path, draft_model_path) as the key. A simple string works as a hashable key but breaks can_trim_prompt_cache semantics and doesn't match the established pattern in BatchedEngine (commit b191aec).

  4. Small one: the x-anthropic- header stripping in anthropic_adapter.py uses text.startswith("x-anthropic-") which could also silently drop user-authored system content that happens to start with that prefix. A more specific match like text.startswith("x-anthropic-billing-header:") or a header-line regex would be safer.
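Point 4 could be addressed with a header-shaped match rather than a bare prefix check. A possible sketch (the regex is a suggestion, not the PR's code): only strings that look like an actual `x-anthropic-<name>: <value>` header line are treated as tracking headers.

```python
import re

# Match only header-shaped lines like "x-anthropic-<name>: <value>",
# rather than any text that merely starts with "x-anthropic-".
_TRACKING_HEADER = re.compile(r"^x-anthropic-[a-z0-9-]+:\s", re.IGNORECASE)

def is_tracking_header(text: str) -> bool:
    """Illustrative helper; name and pattern are suggestions, not the PR's code."""
    return bool(_TRACKING_HEADER.match(text))
```

User-authored system text that merely mentions the prefix (e.g. "x-anthropic-compatible clients should...") would no longer be dropped.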

Issues 1 and 2 together might mean cache hits don't return the right KV state. Happy to help think through the snapshotting approach if that would be useful. The overall design direction is great though, this will be a big win for interactive use cases.
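The ordering fix suggested in points 1 and 2 can be sketched in miniature. All names here are illustrative and KV state is modeled as a token list: the point is simply that the snapshot must be taken after prefill (so it holds the full prompt's state) and must be an independent copy, so later in-place decode mutations cannot corrupt it.

```python
def generate_and_cache(prompt_tokens, cached_state, remaining, store):
    """Toy sketch of snapshot-after-prefill; illustrative only, not vllm_mlx code."""
    kv = list(cached_state)         # start from the fetched prefix state
    kv.extend(remaining)            # prefill: process the non-cached prompt tokens
    snapshot = list(kv)             # snapshot NOW holds the full prompt's KV state
    store(prompt_tokens, snapshot)  # store prompt-only state under the full token key
    kv.append(-1)                   # decode mutates kv in place; snapshot is unaffected
    return kv
```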
