Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions tests/test_paged_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,17 @@ def _set_env():
def vllm_outputs():
"""Run vLLM offline inference once for all prompts.

Uses max_num_seqs=1 to avoid batch-invariance non-determinism on Metal.
Pinned to ``enable_prefix_caching=False`` so the golden token IDs
(cache-off reference) remain the invariant under test regardless of
upstream default changes.
"""
llm = LLM(model=MODEL_NAME, max_model_len=512, max_num_seqs=1)
llm = LLM(
model=MODEL_NAME,
max_model_len=512,
max_num_seqs=1,
enable_prefix_caching=False,
)

# Verify paged KV path is active when requested
if os.environ.get("VLLM_METAL_USE_PAGED_ATTENTION", "0") == "1":
runner = llm.llm_engine.model_executor.driver_worker.model_runner
assert runner._paged_attention_backend is not None, (
Expand Down
104 changes: 104 additions & 0 deletions tests/test_paged_prefix_caching_e2e.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No cache hit assert

Suggest asserting on scheduler cache-hit counters or on block reuse in runner._request_states. Even a single `assert computed_tokens > 0` on at least one request would be enough.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
"""End-to-end correctness of paged prefix caching (issue #182).

Fires the deterministic-test prompts twice through ``vllm.LLM`` with
prefix caching enabled. The first pass primes the cache; the second
pass exercises the model_runner's ``start_pos > 0`` path because the
upstream scheduler reports ``num_computed_tokens > 0``. The asserted
token sequence is the existing cache-off golden, so a broken cache-hit
path surfaces as a token mismatch.

The LLM body runs in a spawned child process (``multiprocessing`` with
the ``spawn`` start method) so Metal device init happens in a fresh
interpreter. This is required on the Metal platform because:
- ``fork`` inherits the parent's Metal context and segfaults in the
child (Metal is not fork-safe).
- Running in the parent pytest process alongside the cache-off
baseline fixture in ``test_paged_deterministic`` causes
``kv_budget=0`` — MLX wired buffers aren't released by Python gc.
"""

from __future__ import annotations

import multiprocessing as mp
import os

import pytest

from tests.test_paged_deterministic import (
DEFAULT_PAGED_MEMORY_FRACTION,
DEFAULT_USE_PAGED_ATTENTION,
)


def _setenv_default(key: str, default: str) -> None:
    """Set environment variable *key* to *default* unless it is already set.

    An existing value is never overwritten, even when it is an empty
    string, so explicit choices made by the caller's environment win.
    """
    # os.environ.setdefault writes through to the process environment only
    # when the key is absent.
    os.environ.setdefault(key, default)


def _run_prefix_cache_correctness() -> None:
    """Generate every prompt twice with prefix caching on and check the tokens.

    Intended to be the target of a spawned child process: vllm and the
    golden data are imported inside the function rather than at module
    import time.

    The first ``generate`` call primes the cache and its result is
    discarded. The token IDs from the second call must equal either the
    MLX golden or the paged golden for each prompt; all prompts matching
    neither are collected and reported in a single ``AssertionError``.

    Returns without doing anything when ``VLLM_METAL_USE_PAGED_ATTENTION``
    is not ``"1"`` after defaults are applied.

    NOTE(review): nothing in this body checks that the second pass actually
    hit the prefix cache; only the output tokens are compared — confirm
    whether a cache-hit assertion is wanted here.
    """
    # Apply defaults before the vllm import below; values already present in
    # the environment are left alone.
    env_defaults = (
        ("VLLM_ENABLE_V1_MULTIPROCESSING", "0"),
        ("VLLM_METAL_USE_PAGED_ATTENTION", DEFAULT_USE_PAGED_ATTENTION),
        ("VLLM_METAL_MEMORY_FRACTION", DEFAULT_PAGED_MEMORY_FRACTION),
    )
    for env_key, env_value in env_defaults:
        _setenv_default(env_key, env_value)

    paged_enabled = os.environ.get("VLLM_METAL_USE_PAGED_ATTENTION", "0") == "1"
    if not paged_enabled:
        return  # non-paged path: nothing to test

    from vllm import LLM, SamplingParams

    from tests.test_paged_deterministic import (
        GOLDEN_MLX,
        GOLDEN_PAGED,
        MAX_TOKENS,
        MODEL_NAME,
        PROMPTS,
    )

    engine = LLM(
        model=MODEL_NAME,
        max_model_len=512,
        max_num_seqs=1,
        enable_prefix_caching=True,
    )
    greedy = SamplingParams(temperature=0, max_tokens=MAX_TOKENS)

    # Pass 1 only primes the cache; pass 2 is the one under test.
    engine.generate(PROMPTS, greedy)
    second_pass = engine.generate(PROMPTS, greedy)

    # Index results by prompt text so they can be checked in PROMPTS order.
    result_for = {}
    for result in second_pass:
        result_for[result.prompt] = result

    failures = []
    for prompt in PROMPTS:
        token_ids = list(result_for[prompt].outputs[0].token_ids)
        mlx_expected = GOLDEN_MLX[prompt]
        paged_expected = GOLDEN_PAGED[prompt]
        # Either golden set is an acceptable outcome.
        if token_ids == mlx_expected or token_ids == paged_expected:
            continue
        failures.append(
            f" {prompt!r}\n"
            f" got: {token_ids}\n"
            f" mlx golden: {mlx_expected}\n"
            f" pgd golden: {paged_expected}"
        )

    if not failures:
        return

    report = "\n".join(failures)
    raise AssertionError(
        "Prefix-cached output matched neither golden set for some prompts:\n"
        + report
    )


@pytest.mark.slow
def test_prefix_cached_matches_golden() -> None:
    """Run the prefix-cache correctness body in a spawned child process.

    The child is started with the ``spawn`` start method (see the module
    docstring for why) and the test fails if it exits with a non-zero
    exit code.

    Skips when paged attention is disabled: in that configuration the
    child returns immediately with exit code 0, which would otherwise be
    reported as a pass even though nothing was exercised.
    """
    # Mirror the child's own check: an explicit environment value wins,
    # otherwise the shared default applies.
    paged_flag = os.environ.get(
        "VLLM_METAL_USE_PAGED_ATTENTION", DEFAULT_USE_PAGED_ATTENTION
    )
    if paged_flag != "1":
        pytest.skip("paged attention disabled: prefix-cache e2e path not exercised")

    ctx = mp.get_context("spawn")
    proc = ctx.Process(target=_run_prefix_cache_correctness)
    proc.start()
    proc.join()
    # Only the exit code is inspected here, so any uncaught exception in the
    # child is reported through this single failure message.
    if proc.exitcode != 0:
        raise AssertionError(
            f"Prefix-cache e2e test failed in spawned child "
            f"(exit code: {proc.exitcode})"
        )
14 changes: 8 additions & 6 deletions tools/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,18 @@ python -m tools.benchmark.attention_benchmark --mode varlen --q-lens 1,4,16,64 -

## Prefix Caching Benchmark

Measures TTFT with shared-prefix workloads using `prefix_repetition` dataset.
Establishes a baseline before prefix caching is implemented (#159).
Measures TTFT / TPOT / E2EL with shared-prefix workloads using the
upstream `prefix_repetition` dataset. Compare cache-off baseline vs
cache-on by toggling `--enable-prefix-caching` / `--no-enable-prefix-caching`.

**1. Start the server:**

```bash
# Adjust MEMORY_FRACTION based on available RAM (lower if OOM).
VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.7 \
vllm serve Qwen/Qwen3-0.6B \
--port 8000 --max-model-len 2048 --max-num-seqs 8
--port 8000 --max-model-len 2048 --max-num-seqs 8 \
--enable-prefix-caching
```

**2. Run the benchmark:**
Expand All @@ -93,8 +95,8 @@ vllm bench serve \
--request-rate inf \
--percentile-metrics ttft,tpot,e2el \
--metric-percentiles 50,99 \
--save-result --label baseline
--save-result --label cache-on
```

Key metric is **TTFT** — with prefix caching enabled, requests sharing
the same prefix should show lower TTFT on cache hits.
For a cache-off baseline, restart the server with
`--no-enable-prefix-caching` and re-run with `--label baseline`.
9 changes: 0 additions & 9 deletions vllm_metal/platform.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will turn on prefix caching by default for users. Could you please run a quick benchmark comparing runs with and without prefix caching? We can check TTFT, throughput, and cache hit rate to see whether it actually works. For the dataset, concatenating a shared system prompt should be fine.

Copy link
Copy Markdown
Collaborator Author

@ricky-chaoju ricky-chaoju Apr 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark with vllm bench serve --dataset-name prefix_repetition on Qwen3-0.6B (paged attention). Workload: 1 shared prefix of 512 tokens, 50 prompts each with 128-token suffix and 100-token output; 10 repeats per side, median ± pstdev reported.
benchmark_182

Metric Cache off (med ± sd) Cache on (med ± sd) Δ
Throughput (tok/s) 210.59 ± 8.93 288.88 ± 12.23 1.37×
TTFT mean (ms) 8667.88 ± 588.32 6071.39 ± 554.66 1.43×
TTFT P50 (ms) 8622.69 ± 645.68 5894.11 ± 451.23 1.46×
TTFT P99 (ms) 16308.74 ± 1049.18 11523.16 ± 943.63 1.42×
TPOT P50 (ms) 30.74 ± 1.63 22.88 ± 0.77 1.34×
TPOT P99 (ms) 119.51 ± 51.95 25.62 ± 8.99 4.67×
E2EL P50 (ms) 11574.40 ± 664.36 8141.27 ± 652.99 1.42×

Cache hit rate (from /metrics):

  • Cold (run 1): 78.4% (25088 / 32000)
  • Warm (runs 2-10, each): 97.5% (31200 / 32000)
  • Cumulative across 10 cache-on runs: 95.6% (305888 / 320000)

baseline-* runs report 0 queries / 0 hits (cache disabled).

(On TPOT P99: cache-on stable across 10 runs (std ±9 ms, range 23-50). Baseline noisy (std ±52 ms, range 36-197). Single-sweep ratio floats ~3.5×-5.5× because of baseline noise, not cache-side flakiness. The distribution-based 4.67× is the robust number.)

Original file line number Diff line number Diff line change
Expand Up @@ -275,15 +275,6 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
scheduler_config.max_num_batched_tokens,
)

if config.use_paged_attention and getattr(
cache_config, "enable_prefix_caching", False
):
# The unified paged path does not yet safely support vLLM core
# prefix-cache hits for new requests. Disable the feature at the
# platform layer until that path is fully supported.
cache_config.enable_prefix_caching = False
logger.info("Metal: disabled prefix caching")

# Configure cache — ensure block_size is at least the Metal kernel
# minimum. With chunked prefill enabled, upstream may default to
# block_size=1 for fine-grained scheduling, but our Metal paged
Expand Down
Loading