Figure 8: Single-request GQA decode performance using the llama2-70b setting: tp=2, num_kv_heads=4, num_qo_heads=32, head_dim=128. Sequence length varies from 32 to 65536.
For single-request GQA decoding attention, FlashInfer (Tensor Cores) achieves better performance than FlashAttention 2.4.2 on both A100 and H100, while FlashInfer (CUDA Cores) only reaches 40%+ bandwidth utilization because of the limited performance of CUDA Cores.
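The Tensor Core advantage is easiest to see from the shape of GQA decode: several query heads share one KV head, so the per-KV-head score computation can be batched into a small matrix product instead of independent vector dot products. The snippet below is a plain PyTorch sketch of that reshaping using the Figure 8 head configuration (the sequence length is an arbitrary example value); it is an illustration, not FlashInfer's actual kernel.

```python
import torch

num_qo_heads, num_kv_heads, head_dim, seq_len = 32, 4, 128, 4096
group_size = num_qo_heads // num_kv_heads        # 8 query heads share each KV head

q = torch.randn(num_qo_heads, head_dim)
k = torch.randn(num_kv_heads, seq_len, head_dim)
v = torch.randn(num_kv_heads, seq_len, head_dim)

# Grouping the query heads that share a KV head turns the score computation into a
# (group_size x head_dim) @ (head_dim x seq_len) GEMM per KV head, a shape that MMA /
# Tensor Core instructions handle well; a CUDA-core decode kernel instead processes
# one query vector at a time.
qg = q.view(num_kv_heads, group_size, head_dim)
scores = torch.einsum("ghd,gnd->ghn", qg, k) / head_dim ** 0.5
out = torch.einsum("ghn,gnd->ghd", scores.softmax(-1), v).reshape(num_qo_heads, head_dim)
```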
FlashInfer implements high-performance fp8 decode kernels.
Figure 2. Different orders of merging attention states are mathematically equivalent.
Recursive Attention allows us to decompose attention computation into multiple stages, where different stages can be computed by different kernels and their attention states merged afterwards.
For attention with a shared prefix, we propose the following Divide-and-Conquer algorithm:

1. Use multi-query attention kernel to compute the attention state between queries and KV-Cache of the shared prefix.
2. Use batch decode attention kernel to compute the attention state between queries and KV-Cache of unique suffixes.
3. Use merge operator to combine two attention states to get the final attention output.
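As a concrete reference for the merge in step 3, the sketch below (hypothetical helper name and shapes, not FlashInfer's API) treats an attention state as the pair of a partial attention output `v` and the log-sum-exp `s` of the scores behind it, and combines two such states:

```python
import torch

def merge_state(v_a, s_a, v_b, s_b):
    """Merge two attention states computed over disjoint parts of the KV-Cache.

    v_a, v_b: [num_heads, head_dim] partial attention outputs.
    s_a, s_b: [num_heads] log-sum-exp of the scores behind each partial output.
    """
    m = torch.maximum(s_a, s_b)                      # shift for numerical stability
    w_a, w_b = torch.exp(s_a - m), torch.exp(s_b - m)
    v = (w_a[:, None] * v_a + w_b[:, None] * v_b) / (w_a + w_b)[:, None]
    s = m + torch.log(w_a + w_b)                     # log-sum-exp over both parts
    return v, s
```

Since this merge is a log-sum-exp-weighted average, it is commutative and associative, which is the equivalence illustrated in Figure 2: partial states over disjoint chunks of the KV-Cache can be merged in any order and still reproduce the exact attention output over the full KV-Cache.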
The overall workflow is illustrated on the left side of Figure 3; rectangles of different colors are processed by different thread blocks on the GPU. Note that multi-query attention kernels access the KV-Cache through SMEM or registers, whereas decode kernels can only access the KV-Cache through L2 cache or global memory. Cascade Inference allows us to maximize memory reuse for the common prefix, thus making the attention computation much more memory efficient.
Figure 3. Workflow of Cascade Inference, throughput values adapted from the blog: <a href="https://khairy2011.medium.com/tpu-vs-gpu-vs-cerebras-vs-graphcore-a-fair-comparison-between-ml-hardware-3f5a19d89e38">TPU vs GPU vs Cerebras vs Graphcore: A Fair Comparison between ML Hardware</a>
We call the divide-and-conquer approach for shared-prefix attention the "Cascade Inference".
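Putting the three stages together, here is a small, self-contained sketch of Cascade Inference for a few single-token requests that share one prefix (plain PyTorch with illustrative shapes and hypothetical helper names; in FlashInfer the prefix stage is a multi-query attention kernel over the whole batch and the suffix stage is the batch decode kernel):

```python
import torch

def attention_state(q, k, v):
    """Attention of q against (k, v), returned as an attention state (output, log-sum-exp)."""
    scores = torch.einsum("hd,nhd->hn", q, k) / q.shape[-1] ** 0.5
    return torch.einsum("hn,nhd->hd", scores.softmax(-1), v), torch.logsumexp(scores, -1)

def merge_state(v_a, s_a, v_b, s_b):
    """Same log-sum-exp weighted merge as in the earlier sketch."""
    m = torch.maximum(s_a, s_b)
    w_a, w_b = torch.exp(s_a - m), torch.exp(s_b - m)
    return (w_a[:, None] * v_a + w_b[:, None] * v_b) / (w_a + w_b)[:, None], m + torch.log(w_a + w_b)

# Illustrative sizes, not the benchmark configuration.
num_heads, head_dim, prefix_len = 8, 64, 1024
k_shared = torch.randn(prefix_len, num_heads, head_dim)
v_shared = torch.randn(prefix_len, num_heads, head_dim)

for suffix_len in (32, 64, 128):                        # each request has its own unique suffix
    q = torch.randn(num_heads, head_dim)                # one decode query per request
    k_suf = torch.randn(suffix_len, num_heads, head_dim)
    v_suf = torch.randn(suffix_len, num_heads, head_dim)

    v_p, s_p = attention_state(q, k_shared, v_shared)   # 1. shared prefix (multi-query kernel)
    v_s, s_s = attention_state(q, k_suf, v_suf)         # 2. unique suffix (batch decode kernel)
    out, _ = merge_state(v_p, s_p, v_s, s_s)            # 3. merge the two attention states

    # The cascade is exact: it matches attention over the concatenated KV-Cache.
    ref, _ = attention_state(q, torch.cat([k_shared, k_suf]), torch.cat([v_shared, v_suf]))
    assert torch.allclose(out, ref, atol=1e-4)
```

Because every request attends to the same prefix, the prefix stage touches the shared KV-Cache once for the whole batch, which is where the memory savings described above come from; the final assert checks that the cascaded result matches ordinary attention over the concatenated KV-Cache.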
We evaluate Cascade Inference on H100 SXM 80GB and A100 PCIe 80GB GPUs.
Figure 5. Speedup over vLLM PageAttention on A100 PCIe 80GB
Figures 4 and 5 show the performance of FlashInfer kernels in the cascading and non-cascading settings, normalized to the vLLM implementation. FlashInfer kernels outperform the vLLM kernels in both settings, and the cascading kernels achieve significant speedups over the non-cascading kernels in most cases.
The benefit of Cascade Inference increases as the shared prefix length and batch size grow (where the prefill kernel dominates execution time) and decreases as the unique suffix length increases (where the batch decode kernel dominates execution time). For a very long shared prompt (32768 tokens), the decode kernel can achieve up to 31x speedup on H100 SXM 80GB with large batch sizes (≥128) and short unique kv-lengths (≤256).