@@ -10,7 +10,9 @@ def flash_attn_baseline(
1010 k_cache ,
1111 v_cache ,
1212 causal ,
13+ window_size ,
1314 softmax_scale ,
15+ softmax_sink ,
1416 cache_seqlens ,
1517 page_table ,
1618 cu_seqlens_q ,
@@ -22,6 +24,8 @@ def flash_attn_baseline(
2224 k_cache ,
2325 v_cache ,
2426 causal = causal ,
27+ softmax_sink = softmax_sink ,
28+ window_size = window_size ,
2529 softmax_scale = softmax_scale ,
2630 page_table = page_table ,
2731 cache_seqlens = cache_seqlens ,
@@ -34,20 +38,39 @@ def flash_attn_baseline(
3438
3539# Benchmark configurations
3640causal = [True , False ]
41+ local = [True , False ]
42+ use_softmax_sink = [True , False ]
3743batch_size = [1 , 16 ]
3844q_seq_length_range = [1 , 512 , 1024 ]
3945kv_seq_length_range = [512 , 1024 , 2048 , 4096 , 8192 , 16384 ]
4046page_size_range = [32 , 64 , 128 ]
4147configs = list (
42- product (
43- causal , batch_size , q_seq_length_range , kv_seq_length_range , page_size_range
48+ filter (
49+ lambda cfg : not (cfg [0 ] and cfg [1 ]),
50+ product (
51+ causal ,
52+ local ,
53+ use_softmax_sink ,
54+ batch_size ,
55+ q_seq_length_range ,
56+ kv_seq_length_range ,
57+ page_size_range ,
58+ ),
4459 )
4560)
4661
4762
4863@triton .testing .perf_report (
4964 triton .testing .Benchmark (
50- x_names = ["causal" , "batch_size" , "q_seq_length" , "kv_seq_length" , "page_size" ],
65+ x_names = [
66+ "causal" ,
67+ "local" ,
68+ "use_softmax_sink" ,
69+ "batch_size" ,
70+ "q_seq_length" ,
71+ "kv_seq_length" ,
72+ "page_size" ,
73+ ],
5174 x_vals = [list (c ) for c in configs ],
5275 line_arg = "provider" ,
5376 line_vals = ["flash_attn" ],
@@ -58,7 +81,16 @@ def flash_attn_baseline(
5881 args = {},
5982 )
6083)
61- def benchmark (causal , batch_size , q_seq_length , kv_seq_length , page_size , provider ):
84+ def benchmark (
85+ causal ,
86+ local ,
87+ use_softmax_sink ,
88+ batch_size ,
89+ q_seq_length ,
90+ kv_seq_length ,
91+ page_size ,
92+ provider ,
93+ ):
6294 dtype = torch .bfloat16
6395 device = torch .device ("xpu" )
6496
@@ -93,6 +125,11 @@ def benchmark(causal, batch_size, q_seq_length, kv_seq_length, page_size, provid
93125 dtype = torch .int32 ,
94126 )
95127 max_seqlen_q = q_seq_length
128+ window_size = (- 1 , - 1 ) if not local else torch .randint (0 , kv_seq_length , (2 ,))
129+
130+ softmax_sink = (
131+ torch .randn (num_heads , device = device , dtype = dtype ) if use_softmax_sink else None
132+ )
96133
97134 softmax_scale = 1.0 / (head_dim ** 0.5 )
98135
@@ -105,7 +142,9 @@ def benchmark(causal, batch_size, q_seq_length, kv_seq_length, page_size, provid
105142 k_cache .clone (),
106143 v_cache .clone (),
107144 causal = causal ,
145+ window_size = window_size ,
108146 softmax_scale = softmax_scale ,
147+ softmax_sink = softmax_sink ,
109148 cache_seqlens = cache_seqlens ,
110149 page_table = page_table ,
111150 cu_seqlens_q = cu_seqlens_q ,
@@ -119,3 +158,4 @@ def benchmark(causal, batch_size, q_seq_length, kv_seq_length, page_size, provid
119158
120159if __name__ == "__main__" :
121160 benchmark .run (print_data = True )
161+ print ("Benchmark finished!" )
0 commit comments