1 file changed: tritonbench/operators/decoding_attention (+31 -0)
@@ -72,6 +72,17 @@
     HAS_AITER = False


+# [Optional] flash_fwd cute-DSL backend
+HAS_FLASH_CUTE = True
+try:
+    from flash_attn.cute.interface import (
+        flash_attn_func as flash_attn_cute_func,
+    )
+except (ImportError, IOError, AttributeError):
+    HAS_FLASH_CUTE = False
+    flash_attn_cute_func = None  # Define it as None to avoid NameError
+
+
 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
     parser.add_argument("--batch", type=int, help="Batch size")
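The guard follows the same optional-dependency pattern used for the AITER backend earlier in this file: probe the import once at module load, record availability in a HAS_* flag, and keep the symbol bound (to None) so later references cannot raise NameError. Below is a minimal standalone sketch of that pattern; it assumes nothing about tritonbench internals, and the _probe_backend / _import_flash_cute helpers are hypothetical, not part of this diff.

# Sketch of the optional-backend probe; helper names here are hypothetical.
from typing import Callable, Optional, Tuple


def _probe_backend(import_fn: Callable[[], Callable]) -> Tuple[bool, Optional[Callable]]:
    """Try to resolve a backend entry point; return (available, fn-or-None)."""
    try:
        return True, import_fn()
    except (ImportError, IOError, AttributeError):
        return False, None


def _import_flash_cute() -> Callable:
    # Same import path as in the hunk above.
    from flash_attn.cute.interface import flash_attn_func
    return flash_attn_func


HAS_FLASH_CUTE, flash_attn_cute_func = _probe_backend(_import_flash_cute)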
@@ -559,6 +570,26 @@ def fbgemm_gqa_fp8kv(
             cache_logical_dtype_int=1,  # FP8 = 1
         )

+
+    @register_benchmark(enabled=HAS_FLASH_CUTE)
+    def flash_cute_dsl(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+        """Flash Attention implementation using cute-DSL backend."""
+        # For GQA, cute-DSL handles the head expansion internally;
+        # we pass the original KV tensors without manual expansion.
+        q_heads = q.shape[2]
+        kv_heads = k_cache.shape[2]
+        return lambda: flash_attn_cute_func(
+            q, k_cache, v_cache,
+            causal=CAUSAL,
+            pack_gqa=(q_heads != kv_heads),
+        )
+
     @register_benchmark(enabled=HAS_AITER)
     def aiter_paged_fp8kv(
         self,
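A usage note on pack_gqa: for grouped-query attention the query tensor carries more heads than the KV cache, and the flag asks the backend to handle that expansion internally instead of the caller repeating KV heads. The sketch below mirrors the call in the hunk above under assumed decode shapes; the batch/sequence sizes, dtype, and device are illustrative, causal=True stands in for the module-level CAUSAL constant, and HAS_FLASH_CUTE / flash_attn_cute_func refer to the import guard added earlier in this diff.

import torch

# Illustrative GQA decode shapes (layout: batch, seqlen, heads, head_dim — heads on
# dim 2, matching how the benchmark reads q.shape[2] and k_cache.shape[2]).
batch, q_len, kv_len, head_dim = 4, 1, 2048, 128
q_heads, kv_heads = 32, 8  # more query heads than KV heads => grouped-query attention

q = torch.randn(batch, q_len, q_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k_cache = torch.randn(batch, kv_len, kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
v_cache = torch.randn_like(k_cache)

# pack_gqa is only needed when the query/KV head counts differ.
pack_gqa = q.shape[2] != k_cache.shape[2]

if HAS_FLASH_CUTE:
    result = flash_attn_cute_func(q, k_cache, v_cache, causal=True, pack_gqa=pack_gqa)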