Commit e0d4d44

Rename parameter "softmax_sink" to "sinks" for the flash_attn_with_kvcache kernel, as "sinks" is the name used in the sglang framework.
Parent: 1abaed2
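
At a typical call site, the rename amounts to the following (an illustrative hunk, not one from this commit; the surrounding argument names are assumptions):

-    out = flash_attn_with_kvcache(q, k_cache, v_cache, softmax_sink=sinks, causal=True)
+    out = flash_attn_with_kvcache(q, k_cache, v_cache, sinks=sinks, causal=True)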

6 files changed (+23 −23 lines)

benchmark/bench_flash_attn.py

Lines changed: 9 additions & 9 deletions
@@ -12,7 +12,7 @@ def flash_attn_baseline(
     causal,
     window_size,
     softmax_scale,
-    softmax_sink,
+    sinks,
     cache_seqlens,
     page_table,
     cu_seqlens_q,
@@ -24,7 +24,7 @@ def flash_attn_baseline(
         k_cache,
         v_cache,
         causal=causal,
-        softmax_sink=softmax_sink,
+        sinks=sinks,
         window_size=window_size,
         softmax_scale=softmax_scale,
         page_table=page_table,
@@ -39,7 +39,7 @@ def flash_attn_baseline(
 # Benchmark configurations
 causal = [True, False]
 local = [True, False]
-use_softmax_sink = [True, False]
+use_sinks = [True, False]
 batch_size = [1, 16]
 q_seq_length_range = [1, 512, 1024]
 kv_seq_length_range = [512, 1024, 2048, 4096, 8192, 16384]
@@ -50,7 +50,7 @@ def flash_attn_baseline(
     product(
         causal,
         local,
-        use_softmax_sink,
+        use_sinks,
         batch_size,
         q_seq_length_range,
         kv_seq_length_range,
@@ -65,7 +65,7 @@ def flash_attn_baseline(
         x_names=[
             "causal",
             "local",
-            "use_softmax_sink",
+            "use_sinks",
             "batch_size",
             "q_seq_length",
             "kv_seq_length",
@@ -84,7 +84,7 @@ def flash_attn_baseline(
 def benchmark(
     causal,
     local,
-    use_softmax_sink,
+    use_sinks,
     batch_size,
     q_seq_length,
     kv_seq_length,
@@ -127,8 +127,8 @@ def benchmark(
     max_seqlen_q = q_seq_length
     window_size = (-1, -1) if not local else torch.randint(0, kv_seq_length, (2,))

-    softmax_sink = (
-        torch.randn(num_heads, device=device, dtype=dtype) if use_softmax_sink else None
+    sinks = (
+        torch.randn(num_heads, device=device, dtype=dtype) if use_sinks else None
     )

     softmax_scale = 1.0 / (head_dim**0.5)
@@ -144,7 +144,7 @@ def benchmark(
         causal=causal,
         window_size=window_size,
         softmax_scale=softmax_scale,
-        softmax_sink=softmax_sink,
+        sinks=sinks,
         cache_seqlens=cache_seqlens,
         page_table=page_table,
         cu_seqlens_q=cu_seqlens_q,

include/sgl_flash_kernel_ops.h

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ std::vector<at::Tensor> mha_fwd(
     std::optional<at::Tensor>& k_descale_,  // (b, h_k)
     std::optional<at::Tensor>& v_descale_,  // (b, h_k)
     float const softmax_scale,
-    std::optional<const at::Tensor>& softmax_sink,
+    std::optional<const at::Tensor>& sinks,
     bool is_causal,
     int window_size_left,
     int window_size_right,

python/sgl_kernel/flash_attn.py

Lines changed: 2 additions & 2 deletions
@@ -49,7 +49,7 @@ def flash_attn_with_kvcache(
     k_descale: Optional[torch.Tensor] = None,
     v_descale: Optional[torch.Tensor] = None,
     softmax_scale=None,
-    softmax_sink=None,
+    sinks=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
     softcap=0.0,  # 0.0 means deactivated
@@ -205,7 +205,7 @@ def flash_attn_with_kvcache(
         k_descale,
         v_descale,
         softmax_scale,
-        softmax_sink,
+        sinks,
         causal,
         window_size[0],
         window_size[1],
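
A minimal call-site sketch under the new keyword. This is a sketch only, not code from this commit: tensor shapes assume the usual flash-attn layout of (batch, seqlen, nheads, head_dim), and the sizes, device, and dtype below are illustrative.

import torch
from sgl_kernel.flash_attn import flash_attn_with_kvcache

batch, seqlen_q, seqlen_k = 2, 1, 1024           # illustrative sizes
nheads, nheads_k, head_dim = 32, 8, 128
device, dtype = "cuda", torch.bfloat16           # assumed; this repo also ships a SYCL backend

q = torch.randn(batch, seqlen_q, nheads, head_dim, device=device, dtype=dtype)
k_cache = torch.randn(batch, seqlen_k, nheads_k, head_dim, device=device, dtype=dtype)
v_cache = torch.randn(batch, seqlen_k, nheads_k, head_dim, device=device, dtype=dtype)
cache_seqlens = torch.full((batch,), seqlen_k, device=device, dtype=torch.int32)

# One value per query head, mirroring the benchmark and tests; pass None to disable.
sinks = torch.randn(nheads, device=device, dtype=dtype)

out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    cache_seqlens=cache_seqlens,
    softmax_scale=head_dim**-0.5,
    sinks=sinks,          # previously softmax_sink=
    causal=True,
)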

src/sycl/chunked_prefill.cpp

Lines changed: 3 additions & 3 deletions
@@ -489,7 +489,7 @@ std::vector<at::Tensor> mha_fwd(
     std::optional<at::Tensor>& k_descale_,  // (b, h_k)
     std::optional<at::Tensor>& v_descale_,  // (b, h_k)
     const float softmax_scale_,
-    std::optional<const at::Tensor>& softmax_sink_,
+    std::optional<const at::Tensor>& sinks_,
     bool is_causal,
     int window_size_left,
     int window_size_right,
@@ -643,8 +643,8 @@ std::vector<at::Tensor> mha_fwd(

   // Set the different scale values.
   params.scale_softmax = softmax_scale;
-  bool use_sink = softmax_sink_.has_value();
-  params.sink_softmax = use_sink ? softmax_sink_.value().data_ptr() : nullptr;
+  bool use_sink = sinks_.has_value();
+  params.sink_softmax = use_sink ? sinks_.value().data_ptr() : nullptr;

   params.softcap = softcap;
src/torch_extension_sycl.cc

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       " Tensor? k_descale,"
       " Tensor? v_descale,"
       " float softmax_scale,"
-      " Tensor? softmax_sink,"
+      " Tensor? sinks,"
       " bool is_causal,"
       " int window_size_left,"
      " int window_size_right,"

tests/test_flash_attention.py

Lines changed: 7 additions & 7 deletions
@@ -489,7 +489,7 @@ def generate_qkv(
 # )
 @pytest.mark.parametrize("causal,local", [(False, True), (False, False), (True, False)])
 # @pytest.mark.parametrize("causal,local", [(True, False)])
-@pytest.mark.parametrize("use_softmax_sink", [True, False])
+@pytest.mark.parametrize("use_sinks", [True, False])
 # @pytest.mark.parametrize(
 #     "seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]
 # )
@@ -556,7 +556,7 @@ def test_flash_attn_kvcache(
     seqlen_new_eq_seqlen_q,
     causal,
     local,
-    use_softmax_sink,
+    use_sinks,
     new_kv,
     mha_type,
     dtype,
@@ -586,8 +586,8 @@ def test_flash_attn_kvcache(
     assert nheads % nheads_k == 0
     dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
     dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
-    if use_softmax_sink:
-        softmax_sink = torch.randn(nheads, device=device, dtype=dtype_ref)
+    if use_sinks:
+        sinks = torch.randn(nheads, device=device, dtype=dtype_ref)
     if dtype == torch.float8_e4m3fn or not is_hopper():
         # for fp8 and ampere arch, we not support v head dim != qk head dim
         dv_vals = [d]
@@ -831,7 +831,7 @@ def test_flash_attn_kvcache(
         k_cache_rep,
         v_cache_rep,
         softmax_scale,
-        softmax_sink if use_softmax_sink else None,
+        sinks if use_sinks else None,
         query_padding_mask,
         key_padding_mask,
         causal=causal,
@@ -844,7 +844,7 @@ def test_flash_attn_kvcache(
         k_cache_rep,
         v_cache_rep,
         softmax_scale,
-        softmax_sink if use_softmax_sink else None,
+        sinks if use_sinks else None,
         query_padding_mask,
         key_padding_mask,
         causal=causal,
@@ -905,7 +905,7 @@ def test_flash_attn_kvcache(
         causal=causal,
         window_size=window_size,
         softmax_scale=softmax_scale,
-        softmax_sink=softmax_sink if use_softmax_sink else None,
+        sinks=sinks if use_sinks else None,
         rotary_interleaved=rotary_interleaved,
         scheduler_metadata=scheduler_metadata,
         num_splits=num_splits,
