diff --git a/tritonbench/kernels/blackwell_triton_fused_attention.py b/tritonbench/kernels/blackwell_triton_fused_attention.py
index f5e401ede..90dc93b39 100644
--- a/tritonbench/kernels/blackwell_triton_fused_attention.py
+++ b/tritonbench/kernels/blackwell_triton_fused_attention.py
@@ -76,11 +76,8 @@ def _attn_fwd_subtile(
         qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
     else:
         qk = qk * qk_scale - m_ij[:, None]
-    p = tl.math.exp2(qk)
     # -- compute correction factor
     alpha = tl.math.exp2(m_i - m_ij)
-    if not FADD2_REDUCE:
-        l_ij = tl.sum(p, 1)
 
     # -- update output accumulator --
     BM: tl.constexpr = acc.shape[0]
@@ -98,6 +95,7 @@ def _attn_fwd_subtile(
     else:
         acc = acc * alpha[:, None]
 
+    p = tl.math.exp2(qk)
     PM: tl.constexpr = p.shape[0]
     PN: tl.constexpr = p.shape[1]
     if FADD2_REDUCE:
@@ -105,6 +103,8 @@
         l_ij0, l_ij1 = tl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2)
         l_i0 = l_i0 * alpha + l_ij0
         l_i1 = l_i1 * alpha + l_ij1
+    else:
+        l_ij = tl.sum(p, 1)
 
     # prepare p and v for the dot
     p = p.to(dtype)
@@ -707,6 +707,7 @@ def grid_debug(META):
             warp_specialize=warp_specialize,
             OUTER_LOOP=True,
             dtype=torch_dtype_to_triton(q.dtype),
+            data_partition_factor=2,
             **extra_kern_args,
         )
     else:
diff --git a/tritonbench/kernels/blackwell_triton_fused_attention_dp.py b/tritonbench/kernels/blackwell_triton_fused_attention_dp.py
index 446440fd7..f8db8ab77 100644
--- a/tritonbench/kernels/blackwell_triton_fused_attention_dp.py
+++ b/tritonbench/kernels/blackwell_triton_fused_attention_dp.py
@@ -80,11 +80,8 @@ def _attn_fwd_subtile(
         qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
     else:
         qk = qk * qk_scale - m_ij[:, None]
-    p = tl.math.exp2(qk)
     # -- compute correction factor
     alpha = tl.math.exp2(m_i - m_ij)
-    if not FADD2_REDUCE:
-        l_ij = tl.sum(p, 1)
 
     # -- update output accumulator --
     BM: tl.constexpr = acc.shape[0]
@@ -104,6 +101,7 @@ def _attn_fwd_subtile(
 
     # update m_i and l_i
     # place this at the end of the loop to reduce register pressure
+    p = tl.math.exp2(qk)
     PM: tl.constexpr = p.shape[0]
     PN: tl.constexpr = p.shape[1]
     if FADD2_REDUCE:
@@ -111,6 +109,8 @@
         l_ij0, l_ij1 = tl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2)
         l_i0 = l_i0 * alpha + l_ij0
         l_i1 = l_i1 * alpha + l_ij1
+    else:
+        l_ij = tl.sum(p, 1)
 
     # We can potentially move these to be before updating l_ij, so the dot
     # is not blocked.
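
For context, here is a minimal NumPy sketch (illustrative only, not code from the kernel) of the online-softmax tile update that `_attn_fwd_subtile` performs with base-2 exponentials. It mirrors the reordering in this diff: `p = tl.math.exp2(qk)` is now computed only after the `alpha` correction and accumulator rescale, so `p` stays live for a shorter span (the kernel comment cites reduced register pressure). The function name `online_softmax_step` is hypothetical; variable names follow the kernel, and `qk_scale` is assumed to already include the `log2(e)` factor implied by the base-2 exponentials.

```python
# Hypothetical NumPy model of one online-softmax tile update; names mirror
# _attn_fwd_subtile but this is not the kernel's implementation.
import numpy as np


def online_softmax_step(acc, m_i, l_i, qk, v, qk_scale):
    # New running row-max after scaling the current tile of scores.
    m_ij = np.maximum(m_i, (qk * qk_scale).max(axis=1))
    qk = qk * qk_scale - m_ij[:, None]

    # Correction factor for the previously accumulated output and row sums.
    alpha = np.exp2(m_i - m_ij)
    acc = acc * alpha[:, None]

    # exp2 is taken only here, after the accumulator rescale -- the placement
    # this diff moves `p = tl.math.exp2(qk)` to.
    p = np.exp2(qk)
    l_i = l_i * alpha + p.sum(axis=1)

    # Accumulate this tile's contribution to the output.
    acc = acc + p @ v
    return acc, m_ij, l_i
```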