Skip to content

Commit a62c183

Browse files
add sink and local for fla benchmark (#27)
* add sink and local for fla benchmark
1 parent c72df80 commit a62c183

File tree

1 file changed

+44
-4
lines changed

1 file changed

+44
-4
lines changed

benchmark/bench_flash_attn.py

Lines changed: 44 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,9 @@ def flash_attn_baseline(
1010
k_cache,
1111
v_cache,
1212
causal,
13+
window_size,
1314
softmax_scale,
15+
softmax_sink,
1416
cache_seqlens,
1517
page_table,
1618
cu_seqlens_q,
@@ -22,6 +24,8 @@ def flash_attn_baseline(
2224
k_cache,
2325
v_cache,
2426
causal=causal,
27+
softmax_sink=softmax_sink,
28+
window_size=window_size,
2529
softmax_scale=softmax_scale,
2630
page_table=page_table,
2731
cache_seqlens=cache_seqlens,
@@ -34,20 +38,39 @@ def flash_attn_baseline(
3438

3539
# Benchmark configurations
3640
causal = [True, False]
41+
local = [True, False]
42+
use_softmax_sink = [True, False]
3743
batch_size = [1, 16]
3844
q_seq_length_range = [1, 512, 1024]
3945
kv_seq_length_range = [512, 1024, 2048, 4096, 8192, 16384]
4046
page_size_range = [32, 64, 128]
4147
configs = list(
42-
product(
43-
causal, batch_size, q_seq_length_range, kv_seq_length_range, page_size_range
48+
filter(
49+
lambda cfg: not (cfg[0] and cfg[1]),
50+
product(
51+
causal,
52+
local,
53+
use_softmax_sink,
54+
batch_size,
55+
q_seq_length_range,
56+
kv_seq_length_range,
57+
page_size_range,
58+
),
4459
)
4560
)
4661

4762

4863
@triton.testing.perf_report(
4964
triton.testing.Benchmark(
50-
x_names=["causal", "batch_size", "q_seq_length", "kv_seq_length", "page_size"],
65+
x_names=[
66+
"causal",
67+
"local",
68+
"use_softmax_sink",
69+
"batch_size",
70+
"q_seq_length",
71+
"kv_seq_length",
72+
"page_size",
73+
],
5174
x_vals=[list(c) for c in configs],
5275
line_arg="provider",
5376
line_vals=["flash_attn"],
@@ -58,7 +81,16 @@ def flash_attn_baseline(
5881
args={},
5982
)
6083
)
61-
def benchmark(causal, batch_size, q_seq_length, kv_seq_length, page_size, provider):
84+
def benchmark(
85+
causal,
86+
local,
87+
use_softmax_sink,
88+
batch_size,
89+
q_seq_length,
90+
kv_seq_length,
91+
page_size,
92+
provider,
93+
):
6294
dtype = torch.bfloat16
6395
device = torch.device("xpu")
6496

@@ -93,6 +125,11 @@ def benchmark(causal, batch_size, q_seq_length, kv_seq_length, page_size, provid
93125
dtype=torch.int32,
94126
)
95127
max_seqlen_q = q_seq_length
128+
window_size = (-1, -1) if not local else torch.randint(0, kv_seq_length, (2,))
129+
130+
softmax_sink = (
131+
torch.randn(num_heads, device=device, dtype=dtype) if use_softmax_sink else None
132+
)
96133

97134
softmax_scale = 1.0 / (head_dim**0.5)
98135

@@ -105,7 +142,9 @@ def benchmark(causal, batch_size, q_seq_length, kv_seq_length, page_size, provid
105142
k_cache.clone(),
106143
v_cache.clone(),
107144
causal=causal,
145+
window_size=window_size,
108146
softmax_scale=softmax_scale,
147+
softmax_sink=softmax_sink,
109148
cache_seqlens=cache_seqlens,
110149
page_table=page_table,
111150
cu_seqlens_q=cu_seqlens_q,
@@ -119,3 +158,4 @@ def benchmark(causal, batch_size, q_seq_length, kv_seq_length, page_size, provid
119158

120159
if __name__ == "__main__":
121160
benchmark.run(print_data=True)
161+
print("Benchmark finished!")

0 commit comments

Comments
 (0)