18 changes: 8 additions & 10 deletions benchmark/bench_flash_attn.py
@@ -12,7 +12,7 @@ def flash_attn_baseline(
causal,
window_size,
softmax_scale,
-softmax_sink,
+sinks,
cache_seqlens,
page_table,
cu_seqlens_q,
@@ -24,7 +24,7 @@ def flash_attn_baseline(
k_cache,
v_cache,
causal=causal,
-softmax_sink=softmax_sink,
+sinks=sinks,
window_size=window_size,
softmax_scale=softmax_scale,
page_table=page_table,
@@ -39,7 +39,7 @@ def flash_attn_baseline(
# Benchmark configurations
causal = [True, False]
local = [True, False]
-use_softmax_sink = [True, False]
+use_sinks = [True, False]
batch_size = [1, 16]
q_seq_length_range = [1, 512, 1024]
kv_seq_length_range = [512, 1024, 2048, 4096, 8192, 16384]
@@ -50,7 +50,7 @@ def flash_attn_baseline(
product(
causal,
local,
-use_softmax_sink,
+use_sinks,
batch_size,
q_seq_length_range,
kv_seq_length_range,
@@ -65,7 +65,7 @@ def flash_attn_baseline(
x_names=[
"causal",
"local",
-"use_softmax_sink",
+"use_sinks",
"batch_size",
"q_seq_length",
"kv_seq_length",
@@ -84,7 +84,7 @@ def flash_attn_baseline(
def benchmark(
causal,
local,
-use_softmax_sink,
+use_sinks,
batch_size,
q_seq_length,
kv_seq_length,
@@ -127,9 +127,7 @@ def benchmark(
max_seqlen_q = q_seq_length
window_size = (-1, -1) if not local else torch.randint(0, kv_seq_length, (2,))

-softmax_sink = (
-    torch.randn(num_heads, device=device, dtype=dtype) if use_softmax_sink else None
-)
+sinks = torch.randn(num_heads, device=device, dtype=dtype) if use_sinks else None

softmax_scale = 1.0 / (head_dim**0.5)

@@ -144,7 +142,7 @@ def benchmark(
causal=causal,
window_size=window_size,
softmax_scale=softmax_scale,
-softmax_sink=softmax_sink,
+sinks=sinks,
cache_seqlens=cache_seqlens,
page_table=page_table,
cu_seqlens_q=cu_seqlens_q,
2 changes: 1 addition & 1 deletion include/sgl_flash_kernel_ops.h
@@ -62,7 +62,7 @@ std::vector<at::Tensor> mha_fwd(
std::optional<at::Tensor>& k_descale_, // (b, h_k)
std::optional<at::Tensor>& v_descale_, // (b, h_k)
float const softmax_scale,
-std::optional<const at::Tensor>& softmax_sink,
+std::optional<const at::Tensor>& sinks,
bool is_causal,
int window_size_left,
int window_size_right,
4 changes: 2 additions & 2 deletions python/sgl_kernel/flash_attn.py
@@ -49,7 +49,7 @@ def flash_attn_with_kvcache(
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
softmax_scale=None,
-softmax_sink=None,
+sinks=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
@@ -205,7 +205,7 @@ def flash_attn_with_kvcache(
k_descale,
v_descale,
softmax_scale,
-softmax_sink,
+sinks,
causal,
window_size[0],
window_size[1],
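For Python callers, only the keyword name changes. A minimal call-site sketch with the renamed argument; the shapes, CUDA device, and bf16 dtype below are illustrative assumptions, not taken from this diff:

```python
import torch
from sgl_kernel.flash_attn import flash_attn_with_kvcache

# Illustrative sizes: batch=2, one new query token, 8 heads, head_dim=128, 1024 cached tokens.
q = torch.randn(2, 1, 8, 128, device="cuda", dtype=torch.bfloat16)
k_cache = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
v_cache = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
cache_seqlens = torch.full((2,), 1024, device="cuda", dtype=torch.int32)

# One sink value per head, as in the benchmark and tests; formerly passed as softmax_sink=.
sinks = torch.randn(8, device="cuda", dtype=torch.bfloat16)

out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    cache_seqlens=cache_seqlens,
    softmax_scale=1.0 / 128**0.5,
    sinks=sinks,  # renamed from softmax_sink
    causal=True,
)
```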
6 changes: 3 additions & 3 deletions src/sycl/chunked_prefill.cpp
@@ -489,7 +489,7 @@ std::vector<at::Tensor> mha_fwd(
std::optional<at::Tensor>& k_descale_, // (b, h_k)
std::optional<at::Tensor>& v_descale_, // (b, h_k)
const float softmax_scale_,
-std::optional<const at::Tensor>& softmax_sink_,
+std::optional<const at::Tensor>& sinks_,
bool is_causal,
int window_size_left,
int window_size_right,
@@ -643,8 +643,8 @@ std::vector<at::Tensor> mha_fwd(

// Set the different scale values.
params.scale_softmax = softmax_scale;
-bool use_sink = softmax_sink_.has_value();
-params.sink_softmax = use_sink ? softmax_sink_.value().data_ptr() : nullptr;
+bool use_sink = sinks_.has_value();
+params.sink_softmax = use_sink ? sinks_.value().data_ptr() : nullptr;

params.softcap = softcap;

2 changes: 1 addition & 1 deletion src/torch_extension_sycl.cc
@@ -82,7 +82,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
" Tensor? k_descale,"
" Tensor? v_descale,"
" float softmax_scale,"
-" Tensor? softmax_sink,"
+" Tensor? sinks,"
" bool is_causal,"
" int window_size_left,"
" int window_size_right,"
14 changes: 7 additions & 7 deletions tests/test_flash_attention.py
@@ -489,7 +489,7 @@ def generate_qkv(
# )
@pytest.mark.parametrize("causal,local", [(False, True), (False, False), (True, False)])
# @pytest.mark.parametrize("causal,local", [(True, False)])
-@pytest.mark.parametrize("use_softmax_sink", [True, False])
+@pytest.mark.parametrize("use_sinks", [True, False])
# @pytest.mark.parametrize(
# "seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]
# )
@@ -556,7 +556,7 @@ def test_flash_attn_kvcache(
seqlen_new_eq_seqlen_q,
causal,
local,
-use_softmax_sink,
+use_sinks,
new_kv,
mha_type,
dtype,
@@ -586,8 +586,8 @@ def test_flash_attn_kvcache(
assert nheads % nheads_k == 0
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
-if use_softmax_sink:
-    softmax_sink = torch.randn(nheads, device=device, dtype=dtype_ref)
+if use_sinks:
+    sinks = torch.randn(nheads, device=device, dtype=dtype_ref)
if dtype == torch.float8_e4m3fn or not is_hopper():
# for fp8 and ampere arch, we not support v head dim != qk head dim
dv_vals = [d]
@@ -831,7 +831,7 @@ def test_flash_attn_kvcache(
k_cache_rep,
v_cache_rep,
softmax_scale,
-softmax_sink if use_softmax_sink else None,
+sinks if use_sinks else None,
query_padding_mask,
key_padding_mask,
causal=causal,
@@ -844,7 +844,7 @@ def test_flash_attn_kvcache(
k_cache_rep,
v_cache_rep,
softmax_scale,
-softmax_sink if use_softmax_sink else None,
+sinks if use_sinks else None,
query_padding_mask,
key_padding_mask,
causal=causal,
@@ -905,7 +905,7 @@ def test_flash_attn_kvcache(
causal=causal,
window_size=window_size,
softmax_scale=softmax_scale,
-softmax_sink=softmax_sink if use_softmax_sink else None,
+sinks=sinks if use_sinks else None,
rotary_interleaved=rotary_interleaved,
scheduler_metadata=scheduler_metadata,
num_splits=num_splits,
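For context on what the renamed tensor represents: the tests above allocate `sinks` with one value per attention head, and the SYCL path stores it as `params.sink_softmax`, an extra per-head softmax term. Below is a rough PyTorch sketch of the usual attention-sink formulation — an added per-head logit that joins the softmax normalization but contributes no value — offered as a reading aid, not as the kernel's exact numerics:

```python
import torch

def attention_with_sinks(q, k, v, sinks, softmax_scale):
    """Reference-only sketch. q: (b, h, sq, d), k/v: (b, h, sk, d), sinks: (h,)."""
    logits = torch.einsum("bhqd,bhkd->bhqk", q, k) * softmax_scale
    # One extra logit per head, broadcast over batch and query positions.
    sink_logits = sinks.view(1, -1, 1, 1).expand(*logits.shape[:3], 1)
    all_logits = torch.cat([logits, sink_logits], dim=-1)
    # The sink column absorbs probability mass but is dropped before the value matmul.
    probs = torch.softmax(all_logits.float(), dim=-1)[..., :-1]
    return torch.einsum("bhqk,bhkd->bhqd", probs.to(v.dtype), v)
```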