
feat: implement grammar/structured output support for Metal paged path #280

Merged
LxYuan0420 merged 9 commits into vllm-project:main from sts07142:feat/238-grammar-structured-output
Apr 20, 2026

Conversation

@sts07142
Contributor

implement grammar/structured output support for Metal paged path

Fixes: #238


Test

# PyTest
pytest tests/test_grammar_bitmask.py
# Reproduce
python reproduce_grammar_238.py
reproduce_grammar_238.py
# SPDX-License-Identifier: Apache-2.0
"""Reproduction script for issue #238 — grammar/structured output on Metal paged path.

Goes through the real sample_tokens() → _sample_paged_batch() path so it
demonstrates the bug on `main` (grammar ignored, unconstrained token sampled)
and the fix on our branch (only the allowed token sampled).

Usage:
    source .venv-vllm-metal/bin/activate
    python reproduce_grammar_238.py
"""

import math
import sys
from types import SimpleNamespace

import mlx.core as mx
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.sampler import Sampler

import vllm_metal.v1.model_runner as mr
from vllm_metal.v1.model_runner import (
    RequestState,
    _ExecutionBatch,
    _PagedForwardState,
)

VOCAB = 128
# Token 90 is '{' in most LLaMA-family vocabularies — matches the issue evidence:
#   grammar_bitmask = [[0, 0, 67108864, ...]]
#   word 2, bit 26  →  90 // 32 == 2,  1 << (90 % 32) == 67108864
ALLOWED_TOKEN = 90


def _single_token_bitmask(token_id: int, vocab: int) -> np.ndarray:
    words = math.ceil(vocab / 32)
    bitmask = np.zeros((1, words), dtype=np.int32)
    bitmask[0, token_id // 32] = 1 << (token_id % 32)
    return bitmask


def run() -> None:
    # ---- build a minimal runner stub ----
    runner = mr.MetalModelRunner.__new__(mr.MetalModelRunner)
    runner._sampler = Sampler()
    runner.device = torch.device("cpu")
    runner._vocab_size = VOCAB
    runner._logitsprocs = None
    runner._request_states = {}
    runner._paged_request_seq_lens = {}
    runner._pending_output = None
    runner._execute_model_state = None
    # Patch methods that touch scheduler/paged state we haven't set up
    runner._validate_scheduled_outputs = lambda batch, sched: None
    runner._cleanup_finished_requests = lambda req_ids: None

    # ---- one decode request, greedy sampling ----
    req_state = RequestState(
        token_ids=[1],
        prompt_len=1,
        cache=[],
        sampling_params=SamplingParams(temperature=0.0),
        generator=None,
        generated_tokens=0,
    )
    decode_reqs = [("r0", req_state)]

    # ---- uniform logits: without bitmask argmax == token 0 ----
    logits = mx.zeros((1, 1, VOCAB))

    # ---- scheduler_output stub ----
    scheduler_output = SimpleNamespace(
        scheduled_spec_decode_tokens={},
        num_scheduled_tokens={"r0": 1},
        total_num_scheduled_tokens=1,
        finished_req_ids=set(),
    )

    # ---- grammar: allow ONLY token ALLOWED_TOKEN ----
    grammar_output = SimpleNamespace(
        structured_output_request_ids=["r0"],
        grammar_bitmask=_single_token_bitmask(ALLOWED_TOKEN, VOCAB),
    )

    # ---- wire up _execute_model_state (paged forward already "done") ----
    batch = _ExecutionBatch()
    # _sample_paged_batch writes output via batch.paged_decode_reqs, not
    # decode_reqs directly — this mirrors what _collect_cached_requests does.
    batch.paged_decode_reqs = decode_reqs

    runner._execute_model_state = _PagedForwardState(
        batch=batch,
        prefill_reqs=[],
        decode_reqs=decode_reqs,
        scheduler_output=scheduler_output,
        logits=logits,
        cu_seqlens=[0, 1],
        num_decode=1,
    )

    # ---- call the real sample_tokens path ----
    output = runner.sample_tokens(grammar_output=grammar_output)

    sampled = output.sampled_token_ids[0][0]

    print(f"Vocab size      : {VOCAB}")
    print(f"Allowed token   : {ALLOWED_TOKEN}")
    print(f"Sampled token   : {sampled}")
    print()

    if sampled == ALLOWED_TOKEN:
        print("✓ PASS — grammar bitmask applied correctly")
    else:
        print(f"✗ FAIL — expected token {ALLOWED_TOKEN}, got {sampled}")
        sys.exit(1)


if __name__ == "__main__":
    run()
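The bitmask encoding quoted in the script's comment (word 2, bit 26 encoding token 90) can be sanity-checked with a tiny decoder. decode_bitmask_word is a hypothetical helper for illustration only; it is not part of the PR:

```python
# Decode one packed 32-bit bitmask word back into the token ids it
# allows. decode_bitmask_word is a hypothetical helper, not PR code.
def decode_bitmask_word(word_index: int, word: int) -> list[int]:
    """Token ids whose bit is set in this 32-bit word of the bitmask."""
    return [word_index * 32 + bit for bit in range(32) if word & (1 << bit)]

# Issue #238 evidence: word 2 holds 67108864 == 1 << 26, i.e. token 90.
print(decode_bitmask_word(2, 67108864))  # [90]
```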

Test Result

PyTest

===================================================================== test session starts =====================================================================
platform darwin -- Python 3.12.13, pytest-9.0.3, pluggy-1.6.0
rootdir: /Users/name/Personal/vllm-metal
configfile: pyproject.toml
plugins: asyncio-1.3.0, anyio-4.13.0
asyncio: mode=Mode.AUTO, debug=False, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 30 items

tests/test_grammar_bitmask.py ..............................                                                                                            [100%]

====================================================================== warnings summary =======================================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

.venv-vllm-metal/lib/python3.12/site-packages/torch/jit/_script.py:362: 14 warnings
  /Users/name/Personal/vllm-metal/.venv-vllm-metal/lib/python3.12/site-packages/torch/jit/_script.py:362: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=============================================================== 30 passed, 16 warnings in 3.45s ===============================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

Reproduce

# Before
INFO 04-17 00:24:36 [__init__.py:44] Available plugins for group vllm.platform_plugins:
INFO 04-17 00:24:36 [__init__.py:46] - metal -> vllm_metal:register
INFO 04-17 00:24:36 [__init__.py:49] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.
INFO 04-17 00:24:37 [__init__.py:239] Platform plugin metal is activated
INFO 04-17 00:24:38 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
Vocab size      : 128
Allowed token   : 90
Sampled token   : 0

✗ FAIL — expected token 90, got 0
# After
INFO 04-17 00:23:46 [__init__.py:44] Available plugins for group vllm.platform_plugins:
INFO 04-17 00:23:46 [__init__.py:46] - metal -> vllm_metal:register
INFO 04-17 00:23:46 [__init__.py:49] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.
INFO 04-17 00:23:47 [__init__.py:239] Platform plugin metal is activated
INFO 04-17 00:23:48 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
Vocab size      : 128
Allowed token   : 90
Sampled token   : 90

✓ PASS — grammar bitmask applied correctly

Signed-off-by: Injae Ryou <injaeryou@gmail.com>
Review threads (outdated): vllm_metal/v1/model_runner.py ×3, tests/test_grammar_bitmask.py
Collaborator

@LxYuan0420 left a comment


Good work overall; a few changes are still needed before merging.
The main blockers are:

  1. The non-paged guard fires too late; it should fail fast in execute_model().
  2. The grammar/xgrammar logic should be extracted out of model_runner.py into a dedicated structured-output owner module; as written it is an ownership/design problem.

Signed-off-by: Injae Ryou <injaeryou@gmail.com>
@ericcurtin
Collaborator

_apply_grammar_bitmask_metal is dead code (the non-paged path raises NotImplementedError); also a minor comment about torch_to_mlx aliasing.

Signed-off-by: Injae Ryou <injaeryou@gmail.com>
…ontract

Signed-off-by: Injae Ryou <injaeryou@gmail.com>
@sts07142 sts07142 requested a review from LxYuan0420 April 17, 2026 14:02
Review threads: vllm_metal/v1/structured_output.py ×2 (one outdated)
Collaborator

@LxYuan0420 left a comment


Please fix this to operate only on the sampled rows. Then share at least one paged structured-output verification run with the actual output content, plus brief evidence that the masking path no longer copies the full (total_tokens, vocab) plane to CPU.

…d tests

Signed-off-by: Injae Ryou <injaeryou@gmail.com>
@sts07142
Contributor Author

scripts/verify_structured_output_cpu_transfer.py
#!/usr/bin/env python
# SPDX-License-Identifier: Apache-2.0
"""Visual verification: apply_paged transfers only the constrained rows to CPU.

Usage:
    .venv-vllm-metal/bin/python scripts/verify_structured_output_cpu_transfer.py
"""

import math
from types import SimpleNamespace
from unittest.mock import patch

import mlx.core as mx
import numpy as np

import vllm_metal.v1.structured_output as so
from vllm_metal.v1.structured_output import MetalStructuredOutputApplier

VOCAB = 32_000


def bitmask(tok: int, vocab: int = VOCAB) -> np.ndarray:
    words = math.ceil(vocab / 32)
    bm = np.zeros((1, words), dtype=np.int32)
    bm[0, tok // 32] = np.int32(1 << (tok % 32))
    return bm


class _Spy:
    """Wraps numpy and records every 2-D vocab-width array that np.array()
    materialises — those are the only MLX-to-CPU logit transfers."""

    def __init__(self) -> None:
        self.shapes: list[tuple[int, ...]] = []

    def __getattr__(self, name: str):
        return getattr(np, name)

    def array(self, x, *args, **kwargs):
        arr = np.array(x, *args, **kwargs)
        if isinstance(arr, np.ndarray) and arr.ndim == 2 and arr.shape[-1] == VOCAB:
            self.shapes.append(arr.shape)
        return arr


def run(total: int, constrained_ids: list[str], allowed_tokens: list[int]) -> None:
    decode_reqs = [(f"r{i}", SimpleNamespace()) for i in range(total)]
    grammar = SimpleNamespace(
        structured_output_request_ids=constrained_ids,
        grammar_bitmask=np.vstack([bitmask(t) for t in allowed_tokens]),
    )
    logits = mx.zeros((1, total, VOCAB))
    cu = list(range(total + 1))
    sched = SimpleNamespace(scheduled_spec_decode_tokens={})

    spy = _Spy()
    with patch.object(so, "np", spy):
        result = MetalStructuredOutputApplier().apply_paged(
            sched, grammar, decode_reqs, [], cu, total, logits
        )
    r = np.array(result)

    print(f"\n{'─' * 55}")
    print(f"  total_tokens = {total:>4}   constrained = {len(constrained_ids)}")
    print(f"{'─' * 55}")

    for rid, tok in zip(constrained_ids, allowed_tokens):
        row = int(rid[1:])
        finite = np.where(np.isfinite(r[0, row]))[0].tolist()
        ok = finite == [tok]
        print(f"  row {row:>3} ({rid})  allowed={tok:<5}  finite tokens: {finite}  {'✓' if ok else '✗'}")

    unconstrained = [i for i in range(total) if f"r{i}" not in constrained_ids][:3]
    for row in unconstrained:
        all_fin = bool(np.all(np.isfinite(r[0, row])))
        print(f"  row {row:>3} (unconstrained)           all finite: {all_fin}  {'✓' if all_fin else '✗'}")

    print()
    if spy.shapes:
        s = spy.shapes[0]
        full = total * VOCAB
        actual = s[0] * s[1]
        saved = 100 * (full - actual) / full
        print(f"  CPU transfer : np.array() shape = {s}")
        print(f"  floats sent  : {actual:>10,}  (= {s[0]} rows × {s[1]:,} vocab)")
        print(f"  full plane   : {full:>10,}  (= {total} rows × {VOCAB:,} vocab)")
        print(f"  saved        : {full - actual:>10,}  ({saved:.1f}%)")
    print()


if __name__ == "__main__":
    run(total=128, constrained_ids=["r0", "r1"],         allowed_tokens=[42, 1337])
    run(total=64,  constrained_ids=["r0"],               allowed_tokens=[7])
    run(total=32,  constrained_ids=["r0", "r1", "r2"],   allowed_tokens=[100, 200, 300])

───────────────────────────────────────────────────────
total_tokens = 128 constrained = 2
───────────────────────────────────────────────────────
row 0 (r0) allowed=42 finite tokens: [42] ✓
row 1 (r1) allowed=1337 finite tokens: [1337] ✓
row 2 (unconstrained) all finite: True ✓
row 3 (unconstrained) all finite: True ✓
row 4 (unconstrained) all finite: True ✓

CPU transfer : np.array() shape = (2, 32000)
floats sent : 64,000 (= 2 rows × 32,000 vocab)
full plane : 4,096,000 (= 128 rows × 32,000 vocab)
saved : 4,032,000 (98.4%)

───────────────────────────────────────────────────────
total_tokens = 64 constrained = 1
───────────────────────────────────────────────────────
row 0 (r0) allowed=7 finite tokens: [7] ✓
row 1 (unconstrained) all finite: True ✓
row 2 (unconstrained) all finite: True ✓
row 3 (unconstrained) all finite: True ✓

CPU transfer : np.array() shape = (1, 32000)
floats sent : 32,000 (= 1 rows × 32,000 vocab)
full plane : 2,048,000 (= 64 rows × 32,000 vocab)
saved : 2,016,000 (98.4%)

───────────────────────────────────────────────────────
total_tokens = 32 constrained = 3
───────────────────────────────────────────────────────
row 0 (r0) allowed=100 finite tokens: [100] ✓
row 1 (r1) allowed=200 finite tokens: [200] ✓
row 2 (r2) allowed=300 finite tokens: [300] ✓
row 3 (unconstrained) all finite: True ✓
row 4 (unconstrained) all finite: True ✓
row 5 (unconstrained) all finite: True ✓

CPU transfer : np.array() shape = (3, 32000)
floats sent : 96,000 (= 3 rows × 32,000 vocab)
full plane : 1,024,000 (= 32 rows × 32,000 vocab)
saved : 928,000 (90.6%)
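The row-selective masking the script verifies can be sketched in plain NumPy. This is an illustrative re-implementation with assumed semantics, not the actual vllm-metal code (which operates on MLX arrays inside vllm_metal/v1/structured_output.py): only the constrained rows are touched, and everything else stays finite.

```python
import numpy as np

# Sketch of row-selective bitmask application: mask only the
# constrained rows with -inf, leave unconstrained rows untouched.
# Illustrative only; not the vllm-metal implementation.
def apply_bitmask_rows(logits: np.ndarray, rows: list[int],
                       bitmask: np.ndarray) -> np.ndarray:
    vocab = logits.shape[-1]
    out = logits.copy()
    for row, mask_row in zip(rows, bitmask):
        # Unpack the packed int32 words into a per-token allow-list.
        bits = np.unpackbits(mask_row.view(np.uint8), bitorder="little")
        allow = bits[:vocab].astype(bool)
        out[row] = np.where(allow, out[row], -np.inf)
    return out

logits = np.zeros((4, 64), dtype=np.float32)
bm = np.zeros((1, 2), dtype=np.int32)
bm[0, 1] = np.int32(1 << 10)  # allow only token 42 (word 1, bit 10)
masked = apply_bitmask_rows(logits, [0], bm)
print(np.where(np.isfinite(masked[0]))[0].tolist())  # [42]
```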

@ericcurtin
Collaborator

Build failed. Also:

The defensive assert could be a RuntimeError instead.

…en guard

- Convert defensive assert to raise RuntimeError
- Apply ruff format

Signed-off-by: Injae Ryou <injaeryou@gmail.com>
@sts07142
Contributor Author

@ericcurtin I fixed it.

@sts07142 sts07142 requested a review from LxYuan0420 April 20, 2026 03:20
Collaborator

@LxYuan0420 left a comment


Great work! I verified end to end on this branch with Qwen/Qwen3-0.6B: I sent a /v1/chat/completions request with response_format.type=json_schema and got valid schema-constrained JSON back.

@LxYuan0420 merged commit fca88ea into vllm-project:main on Apr 20, 2026
5 checks passed


Development

Successfully merging this pull request may close these issues.

[Feature] v1 Engine Metal Backend Missing Grammar/Structured Output Support - Bitmask Ignored in sample_tokens

3 participants