feat: implement grammar/structured output support for Metal paged path#280
Conversation
Signed-off-by: Injae Ryou <injaeryou@gmail.com>
Signed-off-by: Injae Ryou <injaeryou@gmail.com>
LxYuan0420
left a comment
There was a problem hiding this comment.
Good work overall; a few changes are still needed before merging this.
The main blockers are:
- the non-paged guard is too late and should fail fast in execute_model()
- the grammar/xgrammar logic should be extracted out of model_runner.py into a dedicated structured-output owner/module; this is an ownership/design problem
Signed-off-by: Injae Ryou <injaeryou@gmail.com>
Signed-off-by: Injae Ryou <injaeryou@gmail.com>
Signed-off-by: Injae Ryou <injaeryou@gmail.com>
|
_apply_grammar_bitmask_metal is dead code (non-paged path raises NotImplementedError), minor torch_to_mlx aliasing comment |
Signed-off-by: Injae Ryou <injaeryou@gmail.com>
…ontract Signed-off-by: Injae Ryou <injaeryou@gmail.com>
LxYuan0420
left a comment
There was a problem hiding this comment.
Please fix this to operate only on the sampled rows, then share at least one paged structured-output verification run with actual output content and brief evidence that the masking path no longer copies the full (total_tokens, vocab) plane.
…d tests Signed-off-by: Injae Ryou <injaeryou@gmail.com>
code
#!/usr/bin/env python
# SPDX-License-Identifier: Apache-2.0
"""Visual verification: apply_paged transfers only the constrained rows to CPU.
Usage:
.venv-vllm-metal/bin/python scripts/verify_structured_output_cpu_transfer.py
"""
import math
from types import SimpleNamespace
from unittest.mock import patch
import mlx.core as mx
import numpy as np
import vllm_metal.v1.structured_output as so
from vllm_metal.v1.structured_output import MetalStructuredOutputApplier
VOCAB = 32_000


def bitmask(tok: int, vocab: int = VOCAB) -> np.ndarray:
    """Build a one-row packed bitmask with only the bit for ``tok`` set.

    Tokens are packed 32 per ``int32`` word: token ``t`` lives in word
    ``t // 32`` at bit position ``t % 32``.

    Args:
        tok: Token id whose bit should be set.
        vocab: Vocabulary size; determines the number of 32-bit words.

    Returns:
        An ``int32`` array of shape ``(1, ceil(vocab / 32))``.
    """
    words = math.ceil(vocab / 32)
    bm = np.zeros((1, words), dtype=np.int32)
    bit = 1 << (tok % 32)
    # Bit 31 is the int32 sign bit: ``1 << 31`` does not fit in a signed
    # 32-bit integer (NumPy 2 raises OverflowError on the conversion), so fold
    # it into the equivalent two's-complement value before storing.
    if bit >= 1 << 31:
        bit -= 1 << 32
    bm[0, tok // 32] = bit
    return bm
class _Spy:
    """Stand-in for the ``numpy`` module that logs vocab-width conversions.

    Any attribute not defined here is forwarded to the real ``numpy``.
    ``array`` delegates to ``np.array`` and remembers the shape of every 2-D
    result whose last dimension equals ``VOCAB``; the script treats those as
    the MLX-to-CPU logit transfers.
    """

    def __init__(self) -> None:
        # Shapes of the recorded arrays, in call order.
        self.shapes: list[tuple[int, ...]] = []

    def __getattr__(self, name: str):
        # Only consulted for names not found on the instance or class, so
        # ``array`` and ``shapes`` are never forwarded.
        return getattr(np, name)

    def array(self, x, *args, **kwargs):
        """Call ``np.array`` and record the result's shape if it qualifies."""
        materialised = np.array(x, *args, **kwargs)
        is_vocab_plane = (
            isinstance(materialised, np.ndarray)
            and materialised.ndim == 2
            and materialised.shape[-1] == VOCAB
        )
        if is_vocab_plane:
            self.shapes.append(materialised.shape)
        return materialised
def run(total: int, constrained_ids: list[str], allowed_tokens: list[int]) -> None:
    """Run one ``apply_paged`` call and print what was masked and transferred.

    Builds ``total`` single-token decode requests named ``r0``..``r{total-1}``,
    constrains the requests in ``constrained_ids`` to the matching entry of
    ``allowed_tokens``, and calls ``MetalStructuredOutputApplier.apply_paged``
    with ``np`` in the structured-output module replaced by a ``_Spy``.

    Args:
        total: Number of decode requests (one logit row each).
        constrained_ids: Request ids of the form ``r<row>`` to constrain.
        allowed_tokens: The token allowed for each constrained request, in
            the same order as ``constrained_ids``.

    Raises:
        ValueError: If ``constrained_ids`` and ``allowed_tokens`` differ in
            length, since the per-row check would otherwise silently skip
            the unmatched entries.
    """
    if len(constrained_ids) != len(allowed_tokens):
        raise ValueError(
            "constrained_ids and allowed_tokens must have the same length "
            f"(got {len(constrained_ids)} and {len(allowed_tokens)})"
        )

    decode_reqs = [(f"r{i}", SimpleNamespace()) for i in range(total)]
    grammar = SimpleNamespace(
        structured_output_request_ids=constrained_ids,
        grammar_bitmask=np.vstack([bitmask(t) for t in allowed_tokens]),
    )
    logits = mx.zeros((1, total, VOCAB))
    # One token per request, so the cumulative offsets are simply 0..total.
    cu = list(range(total + 1))
    sched = SimpleNamespace(scheduled_spec_decode_tokens={})

    spy = _Spy()
    # Only the module under test sees the spy; this script's own ``np`` is
    # untouched, so the conversion below is not recorded.
    with patch.object(so, "np", spy):
        result = MetalStructuredOutputApplier().apply_paged(
            sched, grammar, decode_reqs, [], cu, total, logits
        )
    r = np.array(result)

    print(f"\n{'─' * 55}")
    print(f" total_tokens = {total:>4} constrained = {len(constrained_ids)}")
    print(f"{'─' * 55}")

    # Constrained rows: exactly the allowed token should remain finite.
    for rid, tok in zip(constrained_ids, allowed_tokens):
        row = int(rid[1:])
        finite = np.where(np.isfinite(r[0, row]))[0].tolist()
        ok = finite == [tok]
        print(f" row {row:>3} ({rid}) allowed={tok:<5} finite tokens: {finite} {'✓' if ok else '✗'}")

    # Spot-check the first three unconstrained rows: nothing should be masked.
    constrained = set(constrained_ids)
    unconstrained = [i for i in range(total) if f"r{i}" not in constrained][:3]
    for row in unconstrained:
        all_fin = bool(np.all(np.isfinite(r[0, row])))
        print(f" row {row:>3} (unconstrained) all finite: {all_fin} {'✓' if all_fin else '✗'}")
    print()

    if not spy.shapes:
        # Say so explicitly instead of printing nothing, so a run in which the
        # spy never fired is not mistaken for a successful one.
        print(" CPU transfer : none recorded")
        print()
        return

    # Report every recorded transfer, not just the first: a later full-plane
    # copy would otherwise go unnoticed.
    full = total * VOCAB
    actual = 0
    for s in spy.shapes:
        sent = s[0] * s[1]
        actual += sent
        print(f" CPU transfer : np.array() shape = {s}")
        print(f" floats sent : {sent:>10,} (= {s[0]} rows × {s[1]:,} vocab)")
    saved = 100 * (full - actual) / full if full else 0.0
    print(f" full plane : {full:>10,} (= {total} rows × {VOCAB:,} vocab)")
    print(f" saved : {full - actual:>10,} ({saved:.1f}%)")
    print()
if __name__ == "__main__":
run(total=128, constrained_ids=["r0", "r1"], allowed_tokens=[42, 1337])
run(total=64, constrained_ids=["r0"], allowed_tokens=[7])
run(total=32, constrained_ids=["r0", "r1", "r2"], allowed_tokens=[100, 200, 300])
───────────────────────────────────────────────────────
CPU transfer : np.array() shape = (2, 32000)
───────────────────────────────────────────────────────
CPU transfer : np.array() shape = (1, 32000)
───────────────────────────────────────────────────────
CPU transfer : np.array() shape = (3, 32000)
|
|
Build failed. Also: the defensive assert could be a RuntimeError |
…en guard - Convert defensive assert to raise RuntimeError - Apply ruff format Signed-off-by: Injae Ryou <injaeryou@gmail.com>
|
@ericcurtin I fixed it. |
LxYuan0420
left a comment
There was a problem hiding this comment.
Great work! I verified end to end on this branch with Qwen/Qwen3-0.6B and sent a /v1/chat/completions request with response_format.type=json_schema, and got valid schema-constrained JSON back.
implement grammar/structured output support for Metal paged path
Fixes: #238
Test
# PyTest
pytest tests/test_grammar_bitmask.py
# Reproduce
python reproduce_grammar_238.py
reproduce_grammar_238.py
Test Result
PyTest
Reproduce