7 changes: 4 additions & 3 deletions tritonbench/kernels/blackwell_triton_fused_attention.py
@@ -76,11 +76,8 @@ def _attn_fwd_subtile(
         qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
     else:
         qk = qk * qk_scale - m_ij[:, None]
-    p = tl.math.exp2(qk)
     # -- compute correction factor
     alpha = tl.math.exp2(m_i - m_ij)
-    if not FADD2_REDUCE:
-        l_ij = tl.sum(p, 1)
 
     # -- update output accumulator --
     BM: tl.constexpr = acc.shape[0]
@@ -98,13 +95,16 @@ def _attn_fwd_subtile(
     else:
         acc = acc * alpha[:, None]
 
+    p = tl.math.exp2(qk)
     PM: tl.constexpr = p.shape[0]
     PN: tl.constexpr = p.shape[1]
     if FADD2_REDUCE:
         p0, p1 = p.reshape([PM, 2, PN // 2]).permute(0, 2, 1).split()
         l_ij0, l_ij1 = tl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2)
         l_i0 = l_i0 * alpha + l_ij0
         l_i1 = l_i1 * alpha + l_ij1
+    else:
+        l_ij = tl.sum(p, 1)
 
     # prepare p and v for the dot
     p = p.to(dtype)
@@ -707,6 +707,7 @@ def grid_debug(META):
             warp_specialize=warp_specialize,
             OUTER_LOOP=True,
             dtype=torch_dtype_to_triton(q.dtype),
+            data_partition_factor=2,
             **extra_kern_args,
         )
     else:
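For context, the change defers `p = tl.math.exp2(qk)` until after the accumulator has been rescaled by the correction factor `alpha`, so the full probability tile is not live across the rescale. Below is a minimal NumPy sketch of that reordered online-softmax subtile update; the function name, argument shapes, and the plain-Python math are illustrative assumptions, not the Triton kernel itself.

```python
import numpy as np

def subtile_update(acc, l_i, m_i, qk, v, qk_scale):
    """Illustrative model of one online-softmax subtile step (NumPy, not Triton).

    Mirrors the reordered kernel: alpha and the accumulator rescale come
    first, and exp2(qk) is only materialized afterwards, so the large `p`
    tile is not live while the accumulator is being corrected.
    """
    # new running row-max, in the base-2 exponent domain used by the kernel
    m_ij = np.maximum(m_i, (qk * qk_scale).max(axis=1))
    # correction factor for everything accumulated so far
    alpha = np.exp2(m_i - m_ij)
    # rescale the output accumulator and running row sums first ...
    acc = acc * alpha[:, None]
    l_i = l_i * alpha
    # ... then materialize p and fold in this subtile's contribution
    p = np.exp2(qk * qk_scale - m_ij[:, None])
    l_i = l_i + p.sum(axis=1)
    acc = acc + p @ v
    return acc, l_i, m_ij
```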
6 changes: 3 additions & 3 deletions tritonbench/kernels/blackwell_triton_fused_attention_dp.py
@@ -80,11 +80,8 @@ def _attn_fwd_subtile(
         qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
     else:
         qk = qk * qk_scale - m_ij[:, None]
-    p = tl.math.exp2(qk)
     # -- compute correction factor
     alpha = tl.math.exp2(m_i - m_ij)
-    if not FADD2_REDUCE:
-        l_ij = tl.sum(p, 1)
 
     # -- update output accumulator --
     BM: tl.constexpr = acc.shape[0]
@@ -104,13 +101,16 @@ def _attn_fwd_subtile(
 
     # update m_i and l_i
     # place this at the end of the loop to reduce register pressure
+    p = tl.math.exp2(qk)
     PM: tl.constexpr = p.shape[0]
     PN: tl.constexpr = p.shape[1]
     if FADD2_REDUCE:
         p0, p1 = p.reshape([PM, 2, PN // 2]).permute(0, 2, 1).split()
         l_ij0, l_ij1 = tl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2)
         l_i0 = l_i0 * alpha + l_ij0
         l_i1 = l_i1 * alpha + l_ij1
+    else:
+        l_ij = tl.sum(p, 1)
 
     # We can potentially move these to be before updating l_ij, so the dot
     # is not blocked.
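The `FADD2_REDUCE` branch shown above splits `p` into its two column halves and reduces both at once so the row sums can use paired adds, keeping two running sums `l_i0`/`l_i1` that are only combined at the epilogue. A rough NumPy equivalent of the reshape/permute/split bookkeeping follows; the tile shape and random data are assumptions for illustration only.

```python
import numpy as np

# Illustrative NumPy view of the FADD2_REDUCE bookkeeping (assumed shapes).
PM, PN = 128, 128                      # subtile shape; requires PN % 2 == 0
p = np.exp2(np.random.rand(PM, PN))    # stand-in for the probability tile

# reshape([PM, 2, PN // 2]).permute(0, 2, 1).split() in the kernel amounts to
# taking the first and second column halves of p as two separate tiles.
p3 = p.reshape(PM, 2, PN // 2).transpose(0, 2, 1)   # shape [PM, PN // 2, 2]
p0, p1 = p3[..., 0], p3[..., 1]                     # each [PM, PN // 2]

# the paired reduction keeps two running row sums, added only at the end
l_ij0 = p0.sum(axis=1)
l_ij1 = p1.sum(axis=1)
l_ij = l_ij0 + l_ij1

# same result as reducing p directly
assert np.allclose(l_ij, p.sum(axis=1))
```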