
Commit cbff76f

Add block size table for ragged_paged_attention (#8942)

Parent: fb6038d

3 files changed: +91 −29 lines

test/test_pallas.py: +6 −3

@@ -751,12 +751,15 @@ def ragged_paged_attention_wrapper(
     num_seqs_jax = jnp.array([num_seqs], dtype=jnp.int32)

     from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
-    from torch_xla.experimental.custom_kernel import _get_default_ragged_paged_attention_block_size
+    from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size
     if num_kv_pages_per_block is None:
       assert num_queries_per_block is None
       token_num = q.shape[0]
-      num_kv_pages_per_block, num_queries_per_block = _get_default_ragged_paged_attention_block_size(
-          token_num)
+      token_num, q_head_num, _ = q.shape
+      kv_head_num = kv_pages[2] // 2
+      max_model_len = 2048
+      num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
+          q_head_num, kv_head_num, token_num, max_model_len)
     jax_kernel_output = torch.from_numpy(
         np.array(
             jax_ragged_paged_attention(
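
For context, the hint values the test now feeds into get_ragged_attention_tuned_block_size come straight from the input shapes: q is [max_num_batched_tokens, num_q_heads, head_dim], and kv_pages appears to pack the K and V heads together along its third dimension (hence the // 2). A minimal sketch of that derivation, with made-up tensor sizes and the head count read via .shape (both are illustrative assumptions, not taken from the test), while max_model_len stays at the fixed 2048 the test uses:

import torch

# Illustrative sizes only: 4096 batched tokens, 32 query heads, 8 KV heads,
# head_dim 128, 1000 KV pages of 16 tokens each.
q = torch.zeros(4096, 32, 128)  # [max_num_batched_tokens, num_q_heads, head_dim]
kv_pages = torch.zeros(1000, 16, 2 * 8, 128)  # [num_pages, page_size, 2 * num_kv_heads, head_dim]

token_num, q_head_num, _ = q.shape
kv_head_num = kv_pages.shape[2] // 2  # K and V heads share the third dimension
max_model_len = 2048  # fixed hint used by the test

print(token_num, q_head_num, kv_head_num)  # 4096 32 8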

torch_xla/experimental/custom_kernel.py: +6 −26

@@ -9,6 +9,7 @@
 from torch_xla.distributed.spmd import Mesh
 import torch_xla.distributed.spmd as xs
 from torch_xla._internal.jax_workarounds import requires_jax
+from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size

 # Re-expose this API used that is referenced by docs
 from torch_xla._internal.jax_workarounds import jax_import_guard  # noqa: F401, pylint: disable=unused-import
@@ -915,29 +916,6 @@ def _ragged_paged_attention_nonkernel(
   return torch.cat(outputs, dim=0)


-def _get_default_ragged_paged_attention_block_size(token_num):
-  tpu_version = torch_xla.tpu.version()
-  if tpu_version < 4:
-    raise NotImplementedError("TPU version must be 4 or higher.")
-  if tpu_version == 4:
-    # This default block size is not tuned, only make sure there's no
-    # OOM in vmem
-    num_kv_pages_per_block = 16
-    num_queries_per_block = 128
-    return num_kv_pages_per_block, num_queries_per_block
-
-  # This heristic is based on the initial kernel micro benchmarking:
-  # When the token_num is small, there's no long request of prefill.
-  # While when it's larger, the block size is adjusted for it.
-  if token_num <= 128:
-    num_kv_pages_per_block = 128
-    num_queries_per_block = 32
-  else:
-    num_kv_pages_per_block = 128
-    num_queries_per_block = 96
-  return num_kv_pages_per_block, num_queries_per_block
-
-
 @requires_jax
 def ragged_paged_attention(
     q,  # [max_num_batched_tokens, num_q_heads, head_dim]
@@ -952,6 +930,7 @@ def ragged_paged_attention(
     soft_cap: float | None = None,
     mask_value=None,
     use_kernel=True,
+    max_model_len=2048,  # Used as a hint for the kernel block sizes selection
     # kernel tuning parameters
     num_kv_pages_per_block=None,
     num_queries_per_block=None,
@@ -980,9 +959,10 @@ def ragged_paged_attention(

   if num_kv_pages_per_block is None:
     assert num_queries_per_block is None
-    token_num = q.shape[0]
-    num_kv_pages_per_block, num_queries_per_block = _get_default_ragged_paged_attention_block_size(
-        token_num)
+    token_num, q_head_num, _ = q.shape
+    kv_head_num = kv_pages[2] // 2
+    num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
+        q_head_num, kv_head_num, token_num, max_model_len)

   if vmem_limit_bytes is None:
     vmem_limit_bytes = 64 * 1024 * 1024
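
Taken together, ragged_paged_attention now resolves its tuning parameters lazily: explicitly passed block sizes still win, and only when both are left as None does it consult the tuned table, keyed on the query/KV head counts, the batched token count, and the new max_model_len hint. A standalone sketch of that dispatch, assuming the same .shape-based KV head count as above (the helper name _pick_block_sizes is illustrative, not part of the library):

from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size


def _pick_block_sizes(q, kv_pages, max_model_len=2048,
                      num_kv_pages_per_block=None, num_queries_per_block=None):
  # Caller-supplied tuning parameters bypass the table entirely.
  if num_kv_pages_per_block is not None:
    return num_kv_pages_per_block, num_queries_per_block
  token_num, q_head_num, _ = q.shape
  # kv_pages stores K and V heads in one dimension, so halve it for the KV head count.
  kv_head_num = kv_pages.shape[2] // 2
  return get_ragged_attention_tuned_block_size(q_head_num, kv_head_num,
                                               token_num, max_model_len)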
torch_xla/experimental/tuned_block_sizes.py: +79 −0 (new file)

@@ -0,0 +1,79 @@
+import torch_xla
+
+
+def _next_power_of_2_bit_manipulation(x):
+  """
+  Finds the smallest power of 2 >= x using bit manipulation.
+  Assumes x is an integer.
+
+  Args:
+    x: The input number (should be an integer).
+
+  Returns:
+    The smallest integer power of 2 that is >= x.
+    Returns 1 if x <= 0.
+  """
+  if x <= 0:
+    return 1
+  if x == 1:
+    return 1
+  return 1 << (x - 1).bit_length()
+
+
+# ragged_paged_attention
+# key: (q_head_num, kv_head_num, token_num, max_model_len)
+# value: (num_kv_pages_per_block, num_queries_per_block)
+
+
+def _simplify_key_ragged_paged_attention(q_head_num, kv_head_num, token_num,
+                                         max_model_len):
+  token_num = _next_power_of_2_bit_manipulation(token_num)
+  max_model_len = _next_power_of_2_bit_manipulation(max_model_len)
+  return q_head_num, kv_head_num, token_num, max_model_len
+
+
+# TODO: add more tuned block sizes in the table
+_ragged_attention_table = {
+    (32, 8, 4096, 2048): (128, 64),
+    (4, 1, 4096, 2048): (128, 128),
+    (32, 8, 2048, 2048): (128, 32),
+    (4, 1, 2048, 2048): (128, 64),
+    (32, 8, 1024, 2048): (64, 32),
+    (1, 1, 1024, 2048): (64, 32),
+    (32, 8, 4096, 4096): (128, 64),
+    (4, 1, 4096, 4096): (128, 128),
+    (32, 8, 2048, 4096): (128, 32),
+    (4, 1, 2048, 4096): (128, 64),
+    (32, 8, 1024, 4096): (64, 32),
+    (1, 1, 1024, 4096): (64, 32),
+    (32, 8, 4096, 64): (32, 32),
+    (4, 1, 4096, 64): (32, 32),
+    (32, 8, 2048, 64): (32, 32),
+    (4, 1, 2048, 64): (32, 32),
+    (32, 8, 1024, 64): (32, 32),
+    (1, 1, 1024, 64): (32, 32),
+    (32, 8, 4096, 128): (32, 32),
+    (4, 1, 4096, 128): (32, 32),
+    (32, 8, 2048, 128): (32, 32),
+    (4, 1, 2048, 128): (32, 32),
+    (32, 8, 1024, 128): (32, 32),
+    (1, 1, 1024, 128): (32, 32),
+}
+
+
+def get_ragged_attention_tuned_block_size(q_head_num, kv_head_num, token_num,
+                                          max_model_len):
+  tpu_version = torch_xla.tpu.version()
+  if tpu_version < 4:
+    raise NotImplementedError("TPU version must be 4 or higher.")
+  if tpu_version == 4:
+    # This default block size is not tuned, only make sure there's no
+    # OOM in vmem
+    num_kv_pages_per_block = 16
+    num_queries_per_block = 128
+    return num_kv_pages_per_block, num_queries_per_block
+
+  key = _simplify_key_ragged_paged_attention(q_head_num, kv_head_num, token_num,
+                                             max_model_len)
+  block_sizes = _ragged_attention_table.get(key, (128, 32))
+  return block_sizes
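
The lookup is easy to reason about: token_num and max_model_len are rounded up to the next power of two so nearby workloads share one table entry, unknown keys fall back to the generic (128, 32) pair, and TPU v4 always gets the untuned but vmem-safe (16, 128). A small worked example against the helpers above (it exercises only the pure-Python parts, so torch_xla.tpu.version() is never called; the private names are imported purely for illustration):

from torch_xla.experimental import tuned_block_sizes as tbs

# 3000 tokens rounds up to 4096 and a 1500-token model length rounds up to 2048,
# so this workload reuses the (32, 8, 4096, 2048) entry.
key = tbs._simplify_key_ragged_paged_attention(32, 8, 3000, 1500)
print(key)  # (32, 8, 4096, 2048)
print(tbs._ragged_attention_table[key])  # (128, 64)

# A configuration that was never benchmarked falls back to the default pair.
print(tbs._ragged_attention_table.get((16, 2, 4096, 2048), (128, 32)))  # (128, 32)

On TPU v5 and newer, get_ragged_attention_tuned_block_size performs exactly this simplify-then-lookup sequence; callers that already know better block sizes for their model can still pass num_kv_pages_per_block and num_queries_per_block explicitly and skip the table.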
