Skip to content

Commit 56ae5b5

Browse files
committed
Add p quantization to our triton fa kernel
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
1 parent 07ce8e5 commit 56ae5b5

20 files changed

Lines changed: 1024 additions & 45 deletions

File tree

examples/vllm_serve/vllm_reload_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def _convert_key_for_vllm(key: str, value: Any) -> tuple[str, str | None, Any]:
8989
if "quantizer" not in key:
9090
return ("copy", key, value)
9191

92-
# Skip softmax_quantizer and lm_head quantizers (not needed in vLLM).
93-
if "softmax_quantizer" in key or (key.startswith("lm_head.") and "quantizer" in key):
92+
# Skip p_bmm_quantizer (softmax-P) and lm_head quantizers (not needed in vLLM).
93+
if "p_bmm_quantizer" in key or (key.startswith("lm_head.") and "quantizer" in key):
9494
return ("skip", None, None)
9595

9696
# Check if this is a q/k/v projection that needs merging

modelopt/torch/kernels/common/attention/__init__.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@
1515

1616
"""Shared Triton kernels for modelopt (attention, quantization, etc.)."""
1717

18+
from collections.abc import Callable
19+
1820
import torch
1921

2022
from modelopt.torch.utils import import_plugin
2123

2224
IS_AVAILABLE = False
23-
attention = None
24-
attention_calibrate = None
25-
register_triton_attention = None
25+
attention: Callable | None = None
26+
register_triton_attention: Callable | None = None
27+
triton_attention_forward: Callable | None = None
28+
validate_triton_attention_envelope: Callable | None = None
2629

2730
if torch.cuda.is_available():
2831
with import_plugin(
@@ -32,26 +35,19 @@
3235
"kernel. Try to install triton with `pip install triton`."
3336
),
3437
):
35-
from .triton_fa import attention as _attention
36-
37-
attention = _attention
38-
IS_AVAILABLE = True
39-
from .hf_triton_attention import register_triton_attention as _register_triton_attention
40-
41-
register_triton_attention = _register_triton_attention
42-
43-
# Calibration lives in the sparsity subpackage (skip-softmax specific).
44-
# Imported here so ``from modelopt.torch.kernels.common.attention import
45-
# attention_calibrate`` keeps working.
46-
from modelopt.torch.kernels.sparsity.attention.calibrate import (
47-
attention_calibrate as _attention_calibrate,
38+
from .hf_triton_attention import (
39+
register_triton_attention,
40+
triton_attention_forward,
41+
validate_triton_attention_envelope,
4842
)
43+
from .triton_fa import attention
4944

50-
attention_calibrate = _attention_calibrate
45+
IS_AVAILABLE = True
5146

5247
__all__ = [
5348
"IS_AVAILABLE",
5449
"attention",
55-
"attention_calibrate",
5650
"register_triton_attention",
51+
"triton_attention_forward",
52+
"validate_triton_attention_envelope",
5753
]

modelopt/torch/kernels/common/attention/hf_triton_attention.py

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,107 @@ def _seq_lens_from_mask(
5050
return None, False
5151

5252

53+
def _check_mask_supported(attention_mask: torch.Tensor | None, seq_q: int) -> None:
54+
"""Reject attention masks this wrapper would silently misread.
55+
56+
The wrapper only derives right-padded per-sequence lengths from 2D
57+
``[batch, q_len]`` masks; anything else either loses padding info (4D
58+
masks) or corrupts the varlen metadata (FA2-style ``[batch, kv_len]``
59+
masks during cached decode).
60+
"""
61+
62+
def _unsupported(reason):
63+
return NotImplementedError(
64+
f"The ModelOpt Triton attention kernel does not support {reason}. "
65+
"Use unpadded (or uniform-length) right-padded inputs."
66+
)
67+
68+
if attention_mask is None:
69+
return
70+
if attention_mask.dim() == 2:
71+
if attention_mask.shape[1] != seq_q:
72+
# FA2-style [batch, kv_len] mask during cached decode: the wrapper
73+
# would misread KV lengths as query lengths (out-of-bounds access).
74+
raise _unsupported("padded batches during cached decode")
75+
mask_bool = attention_mask.to(torch.bool)
76+
if not mask_bool[:, 0].all():
77+
raise _unsupported("left-padded inputs")
78+
# ``_seq_lens_from_mask`` derives lengths via ``sum(dim=1)``, which is only
79+
# correct when each row is a contiguous run of valid tokens followed by
80+
# padding. A hole (e.g. ``[1, 0, 1]``) would sum to the right count but
81+
# place the valid tokens at the wrong positions, so reject non-right-padded
82+
# masks (any valid token after a pad == row not monotonically non-increasing).
83+
if not (mask_bool[:, :-1].int() >= mask_bool[:, 1:].int()).all():
84+
raise _unsupported("non-contiguously padded inputs")
85+
return
86+
# 4D [batch, 1, q, kv] masks are ignored by the wrapper, which is safe only
87+
# when they encode pure causal structure (the kernel masks causally itself).
88+
# In a causal mask the newest query row sees every position; any masked
89+
# entry there means padding, windowing, or a non-causal/bias pattern.
90+
last_row = attention_mask[..., -1, :]
91+
hidden = ~last_row if attention_mask.dtype == torch.bool else last_row != 0
92+
if hidden.any():
93+
raise _unsupported("masks carrying padding or non-causal structure")
94+
95+
96+
def validate_triton_attention_envelope(
97+
module: nn.Module,
98+
query: torch.Tensor,
99+
key: torch.Tensor,
100+
attention_mask: torch.Tensor | None,
101+
**kwargs,
102+
) -> None:
103+
"""Raise ``NotImplementedError`` for inputs outside this wrapper/kernel envelope.
104+
105+
These limits do not come from the quantization or sparsity features layered
106+
on top — they document what the ``triton_fa`` kernel (causal or single-token
107+
decode only; no sliding window, attention sinks, logit softcapping, or
108+
dropout; head_dim >= 16) and this wrapper's varlen-metadata derivation
109+
(right-padded 2D masks only; no multi-token forwards over a longer KV cache)
110+
support. Callers that route arbitrary HF models onto the kernel dynamically
111+
(e.g. the quantization plugin's p_bmm_quantizer dispatch) should call this
112+
before dispatching, so unsupported models fail loudly instead of silently
113+
computing wrong attention. The sparse-attention path predates these checks
114+
and does not yet enforce them.
115+
"""
116+
# Mistral-style models pass sliding_window as an interface kwarg instead of
117+
# setting it on the attention module, so check both.
118+
if getattr(module, "sliding_window", None) or kwargs.get("sliding_window"):
119+
raise NotImplementedError(
120+
"The ModelOpt Triton attention kernel does not support sliding-window attention layers."
121+
)
122+
# Semantic attention arguments the kernel does not implement: dropping them
123+
# would change the attention math.
124+
for name, reason in (("s_aux", "attention sinks"), ("softcap", "logit softcapping")):
125+
if kwargs.get(name) is not None:
126+
raise NotImplementedError(
127+
f"The ModelOpt Triton attention kernel does not support {reason} ('{name}')."
128+
)
129+
if kwargs.get("is_causal") is False or getattr(module, "is_causal", True) is False:
130+
raise NotImplementedError(
131+
"The ModelOpt Triton attention kernel does not support non-causal attention."
132+
)
133+
if kwargs.get("dropout"):
134+
raise NotImplementedError(
135+
"The ModelOpt Triton attention kernel does not support attention dropout; "
136+
"set attention_dropout=0 for training."
137+
)
138+
if query.shape[-1] < 16:
139+
raise NotImplementedError(
140+
f"The ModelOpt Triton attention kernel requires head_dim >= 16, got {query.shape[-1]}."
141+
)
142+
seq_q, seq_k = query.shape[2], key.shape[2]
143+
if seq_q > 1 and seq_k != seq_q:
144+
# The wrapper only passes K-side varlen metadata for single-token decode;
145+
# multi-token forwards over a longer KV cache would mis-index K/V.
146+
raise NotImplementedError(
147+
"The ModelOpt Triton attention kernel does not support multi-token "
148+
"forwards over a longer KV cache (chunked prefill or "
149+
"assisted/speculative decoding)."
150+
)
151+
_check_mask_supported(attention_mask, seq_q)
152+
153+
53154
def triton_attention_forward(
54155
module: nn.Module,
55156
query: torch.Tensor,
@@ -58,6 +159,8 @@ def triton_attention_forward(
58159
attention_mask: torch.Tensor | None,
59160
scaling: float,
60161
dropout: float = 0.0,
162+
p_qdq: str | None = None,
163+
p_qdq_scale: float | None = None,
61164
**kwargs,
62165
) -> tuple[torch.Tensor, None]:
63166
"""Attention forward compatible with HF AttentionInterface.
@@ -75,6 +178,12 @@ def triton_attention_forward(
75178
Other formats (e.g. 4D causal masks) are ignored.
76179
scaling: Softmax scale (e.g. 1/sqrt(head_dim)).
77180
dropout: Ignored (kernel has no dropout); use 0 for eval.
181+
p_qdq: Optional softmax fake quant-dequant mode ("fp8" or
182+
"nvfp4") forwarded to the kernel. Not passed by HF dispatch;
183+
used by direct callers such as the quantization plugin.
184+
p_qdq_scale: Optional per-tensor quantization scale for the
185+
softmax qdq; None uses the kernel default of 1.0 (an effective
186+
amax of 448 for FP8 / 6 * 448 for NVFP4).
78187
**kwargs: Reserved for future extensions.
79188
80189
Returns:
@@ -121,7 +230,7 @@ def triton_attention_forward(
121230
trials = getattr(method, "_threshold_trials", None)
122231
# Deferred: the package __init__ imports this module, so importing
123232
# attention_calibrate at module top would be circular.
124-
from modelopt.torch.kernels.common.attention import attention_calibrate
233+
from modelopt.torch.kernels.sparsity.attention.calibrate import attention_calibrate
125234

126235
if trials and attention_calibrate is not None:
127236
o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials)
@@ -153,6 +262,11 @@ def triton_attention_forward(
153262
if threshold:
154263
kw["skip_softmax_threshold"] = threshold
155264

265+
if p_qdq is not None:
266+
kw["p_qdq"] = p_qdq
267+
if p_qdq_scale is not None:
268+
kw["p_qdq_scale"] = p_qdq_scale
269+
156270
o = attention(q, k, v, **kw)
157271

158272
attn_output = o.view(batch, seq_len, num_heads, head_dim)
@@ -188,4 +302,5 @@ def register_triton_attention() -> bool:
188302
__all__ = [
189303
"register_triton_attention",
190304
"triton_attention_forward",
305+
"validate_triton_attention_envelope",
191306
]

modelopt/torch/kernels/common/attention/triton_fa.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
_apply_sparse_nm_to_qk_tile: Any = None
4343
_is_dense_region: Any = None
4444
_skip_softmax_decision: Any = None
45+
_p_qdq_fp8: Any = None
46+
_p_qdq_nvfp4: Any = None
4547

4648

4749
def _load_sparsity_helpers() -> None:
@@ -62,6 +64,20 @@ def _load_sparsity_helpers() -> None:
6264
_skip_softmax_decision = _skip
6365

6466

67+
def _load_p_qdq_helpers() -> None:
68+
global _p_qdq_fp8, _p_qdq_nvfp4
69+
if _p_qdq_fp8 is None:
70+
from modelopt.torch.kernels.quantization.attention.p_qdq import _p_qdq_nvfp4 as _nvfp4
71+
from modelopt.torch.kernels.quantization.common.fp8_quant import fp8_scalar_qdq as _fp8
72+
73+
_p_qdq_fp8 = _fp8
74+
_p_qdq_nvfp4 = _nvfp4
75+
76+
77+
# Maps the public p_qdq option to the kernel's P_QDQ constexpr.
78+
_P_QDQ_MODES = {None: 0, "fp8": 1, "nvfp4": 2}
79+
80+
6581
LOG2E: float = 1.44269504088896
6682

6783
# ---------------------------------------------------------------------------
@@ -246,6 +262,8 @@ def _attn_fwd(
246262
DENSE_RECENT_TOKENS: tl.constexpr = 64, # Recent KV tokens kept dense (BLOCK_N-independent)
247263
APPLY_SKIP_SOFTMAX: tl.constexpr = False, # Skip KV tiles with negligible scores
248264
SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, # log2(lambda) in the kernel's scaled log2 score space
265+
P_QDQ: tl.constexpr = 0, # Fake quant-dequant of softmax P: 0=off, 1=FP8 E4M3, 2=NVFP4
266+
p_qdq_scale=1.0, # Per-tensor scale for softmax qdq (runtime scalar; amax/448 or amax/(6*448))
249267
Sparsity_total=None, # Optional int64 scalar for counting total tiles (atomic)
250268
Sparsity_skipped=None, # Optional int64 scalar for counting skipped tiles (atomic)
251269
MEASURE_SPARSITY: tl.constexpr = False, # When True, count total/skipped tiles via atomic adds
@@ -383,6 +401,14 @@ def _attn_fwd(
383401
row_sum = row_sum * correction + l_new
384402
acc = acc * correction[:, None]
385403

404+
# --- Optional softmax quant-dequant (emulates quantized P @ V) ---
405+
# row_sum keeps the unquantized p: deployment kernels compute the
406+
# softmax denominator in fp32 and only feed quantized P to BMM2.
407+
if P_QDQ == 1:
408+
p = _p_qdq_fp8(p, p_qdq_scale)
409+
elif P_QDQ == 2:
410+
p = _p_qdq_nvfp4(p, p_qdq_scale, BLOCK_M, BLOCK_N)
411+
386412
# Load V and accumulate
387413
if IS_PAGED:
388414
v = _load_paged_v_tile(
@@ -806,6 +832,8 @@ def forward(
806832
dense_recent_tokens,
807833
skip_softmax_threshold,
808834
measure_sparsity,
835+
p_qdq_mode,
836+
p_qdq_scale,
809837
k_cache,
810838
v_cache,
811839
block_table,
@@ -903,6 +931,8 @@ def forward(
903931
"DENSE_RECENT_TOKENS": dense_recent_tokens,
904932
"APPLY_SKIP_SOFTMAX": apply_skip,
905933
"SKIP_THRESHOLD_LOG2": skip_threshold_log2,
934+
"P_QDQ": p_qdq_mode,
935+
"p_qdq_scale": p_qdq_scale,
906936
"Sparsity_total": sparsity_total,
907937
"Sparsity_skipped": sparsity_skipped,
908938
"MEASURE_SPARSITY": do_measure,
@@ -1106,6 +1136,8 @@ def backward(ctx, grad_output):
11061136
None, # dense_recent_tokens
11071137
None, # skip_softmax_threshold
11081138
None, # measure_sparsity
1139+
None, # p_qdq_mode
1140+
None, # p_qdq_scale
11091141
None, # k_cache
11101142
None, # v_cache
11111143
None, # block_table
@@ -1132,6 +1164,8 @@ def attention(
11321164
dense_recent_tokens: int = 64,
11331165
skip_softmax_threshold: float | None = None,
11341166
measure_sparsity: bool = False,
1167+
p_qdq: str | None = None,
1168+
p_qdq_scale: float = 1.0,
11351169
k_cache: torch.Tensor | None = None,
11361170
v_cache: torch.Tensor | None = None,
11371171
block_table: torch.Tensor | None = None,
@@ -1169,6 +1203,26 @@ def attention(
11691203
and skipped tiles via atomic counters. The counts are stored as
11701204
``_sparsity_total`` and ``_sparsity_skipped`` attributes on the
11711205
returned output tensor.
1206+
p_qdq: Fake quant-dequant of the softmax probabilities ``P``
1207+
before the ``P @ V`` matmul (BMM2), emulating quantized attention.
1208+
``"fp8"`` round-trips P through FP8 E4M3 with a static per-tensor
1209+
scale (amax = 1.0, exact since the kernel's unnormalized P is in
1210+
[0, 1]). ``"nvfp4"`` applies the two-level NVFP4 recipe: E2M1
1211+
elements with one FP8 E4M3 scale per 16 elements along the key
1212+
dimension (the BMM2 contraction axis; every autotuned BLOCK_N is
1213+
a multiple of 16). The softmax denominator stays unquantized, as
1214+
in deployment kernels. The backward pass uses the straight-through
1215+
estimator: gradients are computed from the unquantized P, matching
1216+
QAT references that keep the backward dots in high precision.
1217+
Set to ``None`` to disable.
1218+
p_qdq_scale: Per-tensor quantization scale for the softmax qdq
1219+
(standard convention ``q = cast(p / scale) * scale``). For FP8
1220+
this is ``amax / 448``; for NVFP4 it is the global scale
1221+
``amax / (6 * 448)``. The default of 1.0 corresponds to an
1222+
effective amax of 448 (FP8) or 6 * 448 (NVFP4) — a direct cast
1223+
of the kernel's unnormalized P in [0, 1]. A runtime scalar —
1224+
user-set or calibrated values do not recompile the kernel.
1225+
Out-of-range values saturate.
11721226
k_cache: Paged K cache [num_blocks, page_size, num_kv_heads, head_dim].
11731227
When provided, K/V are read from paged cache via block_table
11741228
instead of from contiguous k/v tensors.
@@ -1186,7 +1240,20 @@ def attention(
11861240
require grad, because the saved ``k``/``v`` are dummy tensors in paged
11871241
mode and dK/dV would be silently incorrect.
11881242
"""
1243+
# Both loaders must run unconditionally: Triton computes a kernel's
1244+
# dependency hash once, on the first call, walking the full AST. If the
1245+
# qdq helpers were still None at that point, their source would be
1246+
# permanently excluded from the cache key and later edits to them would
1247+
# silently reuse stale compiled kernels from the on-disk cache.
11891248
_load_sparsity_helpers()
1249+
_load_p_qdq_helpers()
1250+
if p_qdq not in _P_QDQ_MODES:
1251+
raise ValueError(
1252+
f"p_qdq must be one of {sorted(k for k in _P_QDQ_MODES if k)} or None, got {p_qdq!r}"
1253+
)
1254+
p_qdq_mode = _P_QDQ_MODES[p_qdq]
1255+
if p_qdq_mode and not (math.isfinite(p_qdq_scale) and p_qdq_scale > 0):
1256+
raise ValueError(f"p_qdq_scale must be a finite positive value, got {p_qdq_scale}")
11901257
sm_scale = 1.0 / (q.shape[2] ** 0.5) if softmax_scale is None else softmax_scale
11911258
return _Attention.apply(
11921259
q,
@@ -1206,6 +1273,8 @@ def attention(
12061273
dense_recent_tokens,
12071274
skip_softmax_threshold,
12081275
measure_sparsity,
1276+
p_qdq_mode,
1277+
p_qdq_scale,
12091278
k_cache,
12101279
v_cache,
12111280
block_table,

modelopt/torch/kernels/quantization/attention/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,12 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""Quantization-specific attention kernel pieces (placeholder for combined sparse+quant path)."""
16+
"""Quantization-specific attention kernel pieces.
17+
18+
``p_qdq.py`` holds the softmax-P (``p_bmm_quantizer``) quant-dequant
19+
``@triton.jit`` helpers invoked by the unified flash-attention kernel in
20+
``common/attention/triton_fa.py`` under its ``P_QDQ`` constexpr guard.
21+
Only NVFP4 needs a P-specific helper (tiling and block-amax policy on top of
22+
``quantization/gemm/nvfp4_quant.py``); the FP8 mode uses
23+
``quantization/common/fp8_quant.fp8_scalar_qdq`` directly.
24+
"""

0 commit comments

Comments
 (0)