@@ -50,6 +50,107 @@ def _seq_lens_from_mask(
5050 return None , False
5151
5252
53+ def _check_mask_supported (attention_mask : torch .Tensor | None , seq_q : int ) -> None :
54+ """Reject attention masks this wrapper would silently misread.
55+
56+ The wrapper only derives right-padded per-sequence lengths from 2D
57+ ``[batch, q_len]`` masks; anything else either loses padding info (4D
58+ masks) or corrupts the varlen metadata (FA2-style ``[batch, kv_len]``
59+ masks during cached decode).
60+ """
61+
62+ def _unsupported (reason ):
63+ return NotImplementedError (
64+ f"The ModelOpt Triton attention kernel does not support { reason } . "
65+ "Use unpadded (or uniform-length) right-padded inputs."
66+ )
67+
68+ if attention_mask is None :
69+ return
70+ if attention_mask .dim () == 2 :
71+ if attention_mask .shape [1 ] != seq_q :
72+ # FA2-style [batch, kv_len] mask during cached decode: the wrapper
73+ # would misread KV lengths as query lengths (out-of-bounds access).
74+ raise _unsupported ("padded batches during cached decode" )
75+ mask_bool = attention_mask .to (torch .bool )
76+ if not mask_bool [:, 0 ].all ():
77+ raise _unsupported ("left-padded inputs" )
78+ # ``_seq_lens_from_mask`` derives lengths via ``sum(dim=1)``, which is only
79+ # correct when each row is a contiguous run of valid tokens followed by
80+ # padding. A hole (e.g. ``[1, 0, 1]``) would sum to the right count but
81+ # place the valid tokens at the wrong positions, so reject non-right-padded
82+ # masks (any valid token after a pad == row not monotonically non-increasing).
83+ if not (mask_bool [:, :- 1 ].int () >= mask_bool [:, 1 :].int ()).all ():
84+ raise _unsupported ("non-contiguously padded inputs" )
85+ return
86+ # 4D [batch, 1, q, kv] masks are ignored by the wrapper, which is safe only
87+ # when they encode pure causal structure (the kernel masks causally itself).
88+ # In a causal mask the newest query row sees every position; any masked
89+ # entry there means padding, windowing, or a non-causal/bias pattern.
90+ last_row = attention_mask [..., - 1 , :]
91+ hidden = ~ last_row if attention_mask .dtype == torch .bool else last_row != 0
92+ if hidden .any ():
93+ raise _unsupported ("masks carrying padding or non-causal structure" )
94+
95+
96+ def validate_triton_attention_envelope (
97+ module : nn .Module ,
98+ query : torch .Tensor ,
99+ key : torch .Tensor ,
100+ attention_mask : torch .Tensor | None ,
101+ ** kwargs ,
102+ ) -> None :
103+ """Raise ``NotImplementedError`` for inputs outside this wrapper/kernel envelope.
104+
105+ These limits do not come from the quantization or sparsity features layered
106+ on top — they document what the ``triton_fa`` kernel (causal or single-token
107+ decode only; no sliding window, attention sinks, logit softcapping, or
108+ dropout; head_dim >= 16) and this wrapper's varlen-metadata derivation
109+ (right-padded 2D masks only; no multi-token forwards over a longer KV cache)
110+ support. Callers that route arbitrary HF models onto the kernel dynamically
111+ (e.g. the quantization plugin's p_bmm_quantizer dispatch) should call this
112+ before dispatching, so unsupported models fail loudly instead of silently
113+ computing wrong attention. The sparse-attention path predates these checks
114+ and does not yet enforce them.
115+ """
116+ # Mistral-style models pass sliding_window as an interface kwarg instead of
117+ # setting it on the attention module, so check both.
118+ if getattr (module , "sliding_window" , None ) or kwargs .get ("sliding_window" ):
119+ raise NotImplementedError (
120+ "The ModelOpt Triton attention kernel does not support sliding-window attention layers."
121+ )
122+ # Semantic attention arguments the kernel does not implement: dropping them
123+ # would change the attention math.
124+ for name , reason in (("s_aux" , "attention sinks" ), ("softcap" , "logit softcapping" )):
125+ if kwargs .get (name ) is not None :
126+ raise NotImplementedError (
127+ f"The ModelOpt Triton attention kernel does not support { reason } ('{ name } ')."
128+ )
129+ if kwargs .get ("is_causal" ) is False or getattr (module , "is_causal" , True ) is False :
130+ raise NotImplementedError (
131+ "The ModelOpt Triton attention kernel does not support non-causal attention."
132+ )
133+ if kwargs .get ("dropout" ):
134+ raise NotImplementedError (
135+ "The ModelOpt Triton attention kernel does not support attention dropout; "
136+ "set attention_dropout=0 for training."
137+ )
138+ if query .shape [- 1 ] < 16 :
139+ raise NotImplementedError (
140+ f"The ModelOpt Triton attention kernel requires head_dim >= 16, got { query .shape [- 1 ]} ."
141+ )
142+ seq_q , seq_k = query .shape [2 ], key .shape [2 ]
143+ if seq_q > 1 and seq_k != seq_q :
144+ # The wrapper only passes K-side varlen metadata for single-token decode;
145+ # multi-token forwards over a longer KV cache would mis-index K/V.
146+ raise NotImplementedError (
147+ "The ModelOpt Triton attention kernel does not support multi-token "
148+ "forwards over a longer KV cache (chunked prefill or "
149+ "assisted/speculative decoding)."
150+ )
151+ _check_mask_supported (attention_mask , seq_q )
152+
153+
53154def triton_attention_forward (
54155 module : nn .Module ,
55156 query : torch .Tensor ,
@@ -58,6 +159,8 @@ def triton_attention_forward(
58159 attention_mask : torch .Tensor | None ,
59160 scaling : float ,
60161 dropout : float = 0.0 ,
162+ p_qdq : str | None = None ,
163+ p_qdq_scale : float | None = None ,
61164 ** kwargs ,
62165) -> tuple [torch .Tensor , None ]:
63166 """Attention forward compatible with HF AttentionInterface.
@@ -75,6 +178,12 @@ def triton_attention_forward(
75178 Other formats (e.g. 4D causal masks) are ignored.
76179 scaling: Softmax scale (e.g. 1/sqrt(head_dim)).
77180 dropout: Ignored (kernel has no dropout); use 0 for eval.
181+ p_qdq: Optional softmax fake quant-dequant mode ("fp8" or
182+ "nvfp4") forwarded to the kernel. Not passed by HF dispatch;
183+ used by direct callers such as the quantization plugin.
184+ p_qdq_scale: Optional per-tensor quantization scale for the
185+ softmax qdq; None uses the kernel default of 1.0 (an effective
186+ amax of 448 for FP8 / 6 * 448 for NVFP4).
78187 **kwargs: Reserved for future extensions.
79188
80189 Returns:
@@ -121,7 +230,7 @@ def triton_attention_forward(
121230 trials = getattr (method , "_threshold_trials" , None )
122231 # Deferred: the package __init__ imports this module, so importing
123232 # attention_calibrate at module top would be circular.
124- from modelopt .torch .kernels .common .attention import attention_calibrate
233+ from modelopt .torch .kernels .sparsity .attention . calibrate import attention_calibrate
125234
126235 if trials and attention_calibrate is not None :
127236 o , counters = attention_calibrate (q , k , v , ** kw , threshold_trials = trials )
@@ -153,6 +262,11 @@ def triton_attention_forward(
153262 if threshold :
154263 kw ["skip_softmax_threshold" ] = threshold
155264
265+ if p_qdq is not None :
266+ kw ["p_qdq" ] = p_qdq
267+ if p_qdq_scale is not None :
268+ kw ["p_qdq_scale" ] = p_qdq_scale
269+
156270 o = attention (q , k , v , ** kw )
157271
158272 attn_output = o .view (batch , seq_len , num_heads , head_dim )
@@ -188,4 +302,5 @@ def register_triton_attention() -> bool:
188302__all__ = [
189303 "register_triton_attention" ,
190304 "triton_attention_forward" ,
305+ "validate_triton_attention_envelope" ,
191306]
0 commit comments