diff --git a/benchmark/bench_flash_attn.py b/benchmark/bench_flash_attn.py
index fae410b..c3d5c6a 100644
--- a/benchmark/bench_flash_attn.py
+++ b/benchmark/bench_flash_attn.py
@@ -12,7 +12,7 @@ def flash_attn_baseline(
     causal,
     window_size,
     softmax_scale,
-    softmax_sink,
+    sinks,
     cache_seqlens,
     page_table,
     cu_seqlens_q,
@@ -24,7 +24,7 @@ def flash_attn_baseline(
         k_cache,
         v_cache,
         causal=causal,
-        softmax_sink=softmax_sink,
+        sinks=sinks,
         window_size=window_size,
         softmax_scale=softmax_scale,
         page_table=page_table,
@@ -39,7 +39,7 @@ def flash_attn_baseline(
 # Benchmark configurations
 causal = [True, False]
 local = [True, False]
-use_softmax_sink = [True, False]
+use_sinks = [True, False]
 batch_size = [1, 16]
 q_seq_length_range = [1, 512, 1024]
 kv_seq_length_range = [512, 1024, 2048, 4096, 8192, 16384]
@@ -50,7 +50,7 @@ def flash_attn_baseline(
     product(
         causal,
         local,
-        use_softmax_sink,
+        use_sinks,
         batch_size,
         q_seq_length_range,
         kv_seq_length_range,
@@ -65,7 +65,7 @@ def flash_attn_baseline(
         x_names=[
             "causal",
             "local",
-            "use_softmax_sink",
+            "use_sinks",
             "batch_size",
             "q_seq_length",
             "kv_seq_length",
@@ -84,7 +84,7 @@ def flash_attn_baseline(
 def benchmark(
     causal,
     local,
-    use_softmax_sink,
+    use_sinks,
     batch_size,
     q_seq_length,
     kv_seq_length,
@@ -127,9 +127,7 @@ def benchmark(
     max_seqlen_q = q_seq_length
     window_size = (-1, -1) if not local else torch.randint(0, kv_seq_length, (2,))
-    softmax_sink = (
-        torch.randn(num_heads, device=device, dtype=dtype) if use_softmax_sink else None
-    )
+    sinks = torch.randn(num_heads, device=device, dtype=dtype) if use_sinks else None
     softmax_scale = 1.0 / (head_dim**0.5)
@@ -144,7 +142,7 @@ def benchmark(
         causal=causal,
         window_size=window_size,
         softmax_scale=softmax_scale,
-        softmax_sink=softmax_sink,
+        sinks=sinks,
         cache_seqlens=cache_seqlens,
         page_table=page_table,
         cu_seqlens_q=cu_seqlens_q,
diff --git a/include/sgl_flash_kernel_ops.h b/include/sgl_flash_kernel_ops.h
index 846533d..6cede1c 100644
--- a/include/sgl_flash_kernel_ops.h
+++ b/include/sgl_flash_kernel_ops.h
@@ -62,7 +62,7 @@ std::vector mha_fwd(
     std::optional& k_descale_,  // (b, h_k)
     std::optional& v_descale_,  // (b, h_k)
     float const softmax_scale,
-    std::optional& softmax_sink,
+    std::optional& sinks,
     bool is_causal,
     int window_size_left,
     int window_size_right,
diff --git a/python/sgl_kernel/flash_attn.py b/python/sgl_kernel/flash_attn.py
index 9bd1411..9beb7b1 100644
--- a/python/sgl_kernel/flash_attn.py
+++ b/python/sgl_kernel/flash_attn.py
@@ -49,7 +49,7 @@ def flash_attn_with_kvcache(
     k_descale: Optional[torch.Tensor] = None,
     v_descale: Optional[torch.Tensor] = None,
     softmax_scale=None,
-    softmax_sink=None,
+    sinks=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
     softcap=0.0,  # 0.0 means deactivated
@@ -205,7 +205,7 @@ def flash_attn_with_kvcache(
         k_descale,
         v_descale,
         softmax_scale,
-        softmax_sink,
+        sinks,
         causal,
         window_size[0],
         window_size[1],
diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp
index 14e201e..05ae343 100644
--- a/src/sycl/chunked_prefill.cpp
+++ b/src/sycl/chunked_prefill.cpp
@@ -489,7 +489,7 @@ std::vector mha_fwd(
     std::optional& k_descale_,  // (b, h_k)
     std::optional& v_descale_,  // (b, h_k)
     const float softmax_scale_,
-    std::optional& softmax_sink_,
+    std::optional& sinks_,
     bool is_causal,
     int window_size_left,
     int window_size_right,
@@ -643,8 +643,8 @@ std::vector mha_fwd(
   // Set the different scale values.
   params.scale_softmax = softmax_scale;
-  bool use_sink = softmax_sink_.has_value();
-  params.sink_softmax = use_sink ? softmax_sink_.value().data_ptr() : nullptr;
+  bool use_sink = sinks_.has_value();
+  params.sink_softmax = use_sink ? sinks_.value().data_ptr() : nullptr;
   params.softcap = softcap;
diff --git a/src/torch_extension_sycl.cc b/src/torch_extension_sycl.cc
index 70ba202..16e8ded 100644
--- a/src/torch_extension_sycl.cc
+++ b/src/torch_extension_sycl.cc
@@ -82,7 +82,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       " Tensor? k_descale,"
       " Tensor? v_descale,"
       " float softmax_scale,"
-      " Tensor? softmax_sink,"
+      " Tensor? sinks,"
       " bool is_causal,"
       " int window_size_left,"
       " int window_size_right,"
diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py
index cfed227..3b03b53 100644
--- a/tests/test_flash_attention.py
+++ b/tests/test_flash_attention.py
@@ -489,7 +489,7 @@ def generate_qkv(
 # )
 @pytest.mark.parametrize("causal,local", [(False, True), (False, False), (True, False)])
 # @pytest.mark.parametrize("causal,local", [(True, False)])
-@pytest.mark.parametrize("use_softmax_sink", [True, False])
+@pytest.mark.parametrize("use_sinks", [True, False])
 # @pytest.mark.parametrize(
 #     "seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]
 # )
@@ -556,7 +556,7 @@ def test_flash_attn_kvcache(
     seqlen_new_eq_seqlen_q,
     causal,
     local,
-    use_softmax_sink,
+    use_sinks,
     new_kv,
     mha_type,
     dtype,
@@ -586,8 +586,8 @@ def test_flash_attn_kvcache(
     assert nheads % nheads_k == 0
     dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
     dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
-    if use_softmax_sink:
-        softmax_sink = torch.randn(nheads, device=device, dtype=dtype_ref)
+    if use_sinks:
+        sinks = torch.randn(nheads, device=device, dtype=dtype_ref)
     if dtype == torch.float8_e4m3fn or not is_hopper():
         # for fp8 and ampere arch, we not support v head dim != qk head dim
         dv_vals = [d]
@@ -831,7 +831,7 @@ def test_flash_attn_kvcache(
             k_cache_rep,
             v_cache_rep,
             softmax_scale,
-            softmax_sink if use_softmax_sink else None,
+            sinks if use_sinks else None,
             query_padding_mask,
             key_padding_mask,
             causal=causal,
@@ -844,7 +844,7 @@ def test_flash_attn_kvcache(
             k_cache_rep,
             v_cache_rep,
             softmax_scale,
-            softmax_sink if use_softmax_sink else None,
+            sinks if use_sinks else None,
             query_padding_mask,
             key_padding_mask,
             causal=causal,
@@ -905,7 +905,7 @@ def test_flash_attn_kvcache(
             causal=causal,
             window_size=window_size,
             softmax_scale=softmax_scale,
-            softmax_sink=softmax_sink if use_softmax_sink else None,
+            sinks=sinks if use_sinks else None,
             rotary_interleaved=rotary_interleaved,
             scheduler_metadata=scheduler_metadata,
             num_splits=num_splits,
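
For reference, a minimal caller-side sketch of `flash_attn_with_kvcache` with the renamed `sinks` argument, mirroring the call pattern in the benchmark above. The import path follows the file layout in this diff; the tensor shapes, the device string, and the choice of keyword arguments beyond those shown in the benchmark are assumptions, not taken from this change.

```python
import torch

from sgl_kernel.flash_attn import flash_attn_with_kvcache  # path per python/sgl_kernel/flash_attn.py

# Assumed shapes: q (batch, seqlen_q, nheads, head_dim),
# k_cache / v_cache (batch, seqlen_kv, nheads_k, head_dim).
batch, seqlen_q, seqlen_kv, nheads, head_dim = 1, 1, 512, 8, 128
device = "cuda"  # or "xpu", depending on the build target
dtype = torch.bfloat16

q = torch.randn(batch, seqlen_q, nheads, head_dim, device=device, dtype=dtype)
k_cache = torch.randn(batch, seqlen_kv, nheads, head_dim, device=device, dtype=dtype)
v_cache = torch.randn(batch, seqlen_kv, nheads, head_dim, device=device, dtype=dtype)
cache_seqlens = torch.full((batch,), seqlen_kv, device=device, dtype=torch.int32)

# One sink value per query head, matching the benchmark/test setup in this diff.
sinks = torch.randn(nheads, device=device, dtype=dtype)

out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    window_size=(-1, -1),  # -1 means infinite context window
    softmax_scale=head_dim**-0.5,
    sinks=sinks,  # previously passed as softmax_sink=
)
```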