@@ -489,7 +489,7 @@ def generate_qkv(
489489# )
490490@pytest .mark .parametrize ("causal,local" , [(False , True ), (False , False ), (True , False )])
491491# @pytest.mark.parametrize("causal,local", [(True, False)])
492- @pytest .mark .parametrize ("use_softmax_sink " , [True , False ])
492+ @pytest .mark .parametrize ("use_sinks " , [True , False ])
493493# @pytest.mark.parametrize(
494494# "seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]
495495# )
@@ -556,7 +556,7 @@ def test_flash_attn_kvcache(
556556 seqlen_new_eq_seqlen_q ,
557557 causal ,
558558 local ,
559- use_softmax_sink ,
559+ use_sinks ,
560560 new_kv ,
561561 mha_type ,
562562 dtype ,
@@ -586,8 +586,8 @@ def test_flash_attn_kvcache(
586586 assert nheads % nheads_k == 0
587587 dtype_ref = torch .bfloat16 if dtype == torch .float8_e4m3fn else dtype
588588 dv_vals = [128 , d ] if d > 128 and d <= 192 else ([256 , 512 , d ] if d <= 64 else [d ])
589- if use_softmax_sink :
590- softmax_sink = torch .randn (nheads , device = device , dtype = dtype_ref )
589+ if use_sinks :
590+ sinks = torch .randn (nheads , device = device , dtype = dtype_ref )
591591 if dtype == torch .float8_e4m3fn or not is_hopper ():
592592 # for fp8 and ampere arch, we do not support v head dim != qk head dim
593593 dv_vals = [d ]
@@ -831,7 +831,7 @@ def test_flash_attn_kvcache(
831831 k_cache_rep ,
832832 v_cache_rep ,
833833 softmax_scale ,
834- softmax_sink if use_softmax_sink else None ,
834+ sinks if use_sinks else None ,
835835 query_padding_mask ,
836836 key_padding_mask ,
837837 causal = causal ,
@@ -844,7 +844,7 @@ def test_flash_attn_kvcache(
844844 k_cache_rep ,
845845 v_cache_rep ,
846846 softmax_scale ,
847- softmax_sink if use_softmax_sink else None ,
847+ sinks if use_sinks else None ,
848848 query_padding_mask ,
849849 key_padding_mask ,
850850 causal = causal ,
@@ -905,7 +905,7 @@ def test_flash_attn_kvcache(
905905 causal = causal ,
906906 window_size = window_size ,
907907 softmax_scale = softmax_scale ,
908- softmax_sink = softmax_sink if use_softmax_sink else None ,
908+ sinks = sinks if use_sinks else None ,
909909 rotary_interleaved = rotary_interleaved ,
910910 scheduler_metadata = scheduler_metadata ,
911911 num_splits = num_splits ,
0 commit comments