From 1dabbd661e575940d76415ca74cecac5a360f6dd Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Tue, 18 Mar 2025 21:56:54 -0700 Subject: [PATCH] [ragged-paged-attn] Use hidden states in kv cache and support any num_kv_head (#8851) --- test/test_pallas.py | 4 +- torch_xla/experimental/custom_kernel.py | 80 +------ .../ragged_paged_attention_v2.py | 208 +++++++++--------- 3 files changed, 116 insertions(+), 176 deletions(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index d8d4eaf5b37..6ba879ff76f 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -135,9 +135,9 @@ def _ragged_pagedattention_generate_qkv( "constant", 0) q = torch.randn((max_num_batched_tokens, num_q_heads, head_dim), dtype=dtype) - k_pages = torch.randn((num_pages, page_size, num_kv_heads, head_dim), + k_pages = torch.randn((num_pages, page_size, num_kv_heads * head_dim), dtype=dtype) - v_pages = torch.randn((num_pages, page_size, num_kv_heads, head_dim), + v_pages = torch.randn((num_pages, page_size, num_kv_heads * head_dim), dtype=dtype) page_indices = torch.randint( 0, num_pages, (max_num_seqs, pages_per_seq), dtype=torch.int32) diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index eac28b15493..b131d5807b0 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -893,70 +893,10 @@ def flash_attention( sm_scale, ab, partition_spec, mesh) -def ceil_div(a, b): - assert b != 0 - return (a + b - 1) // b - - -def validate_ragged_paged_attention_inputs( - q, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages, # [total_num_pages, page_size, num_kv_heads, head_dim] - kv_lens, # i32[max_num_seqs] - page_indices, # i32[max_num_seqs, pages_per_seq] - cu_q_lens, # i32[max_num_seqs + 1] - num_seqs, # i32[1] -): - _, num_q_heads, head_dim = q.shape - _, _, num_kv_heads, head_dim_k = k_pages.shape - max_num_seqs, _ = page_indices.shape - if k_pages.shape != v_pages.shape: - raise ValueError( - f"{k_pages.shape=} and {v_pages.shape=} must have the same shape.") - if head_dim_k != head_dim: - raise ValueError( - f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}.") - if kv_lens.shape != (max_num_seqs,): - raise ValueError(f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where" - " `max_num_seqs` is `page_indices.shape[0]`.") - if cu_q_lens.shape != (max_num_seqs + 1,): - raise ValueError( - f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where" - " `max_num_seqs` is `page_indices.shape[0]`.") - if (kv_lens.dtype != torch.int32 or page_indices.dtype != torch.int32 or - cu_q_lens.dtype != torch.int32): - raise ValueError( - "The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be" - f" int32. Got {kv_lens.dtype=}, {page_indices.dtype=}," - f" {cu_q_lens.dtype=}.") - if num_q_heads % num_kv_heads != 0: - raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}") - - # Must check below on runtime! 
- # if num_seqs > max_num_seqs: - # raise ValueError(f"{num_seqs=} must be less or equal to {max_num_seqs=}") - # max_kv_len = torch.max(kv_lens) - # min_pages_per_seq = ceil_div(max_kv_len, page_size) - # if pages_per_seq < min_pages_per_seq: - # raise ValueError( - # f"{pages_per_seq=} must be greater or equal to" - # f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.") - # if cu_q_lens[num_seqs] > max_num_batched_tokens: - # raise ValueError( - # f"Total q tokens {cu_q_lens[num_seqs]} must be less or equal to" - # f" {max_num_batched_tokens=}.") - # for i in range(num_seqs): - # q_len = cu_q_lens[i + 1] - cu_q_lens[i] - # kv_len = kv_lens[i] - # if q_len > kv_len: - # raise ValueError( - # f"{q_len=} must be less or equal to {kv_len=} at sequence {i}.") - - def _ragged_paged_attention_nonkernel( queries, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages, # [total_num_pages, page_size, num_kv_heads, head_dim] + k_pages, # [total_num_pages, page_size, num_kv_heads * head_dim] + v_pages, # [total_num_pages, page_size, num_kv_heads * head_dim] kv_lens, # i32[max_num_seqs] page_indices, # i32[max_num_seqs, pages_per_seq] cu_q_lens, # i32[max_num_seqs + 1] @@ -965,8 +905,9 @@ def _ragged_paged_attention_nonkernel( sm_scale=1.0, mask_value=DEFAULT_MASK_VALUE, ): - _, _, num_kv_heads, head_dim = k_pages.shape - num_q_heads = queries.shape[1] + _, num_q_heads, head_dim = queries.shape + _, _, kv_model_dim = k_pages.shape + num_kv_heads = kv_model_dim // head_dim assert num_q_heads % num_kv_heads == 0 num_query_per_kv = num_q_heads // num_kv_heads outputs = [] @@ -977,8 +918,8 @@ def _ragged_paged_attention_nonkernel( kv_len = kv_lens[i] indices = page_indices[i] q = queries[q_start:q_end] - k = k_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] - v = v_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + k = k_pages[indices, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + v = v_pages[indices, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] k = torch.repeat_interleave(k, num_query_per_kv, dim=1) v = torch.repeat_interleave(v, num_query_per_kv, dim=1) attn = torch.einsum("qhd,khd->hqk", q, k) @@ -998,8 +939,8 @@ def _ragged_paged_attention_nonkernel( @requires_jax def ragged_paged_attention( q, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages, # [total_num_pages, page_size, num_kv_heads, head_dim] + k_pages, # [total_num_pages, page_size, num_kv_heads * head_dim] + v_pages, # [total_num_pages, page_size, num_kv_heads * head_dim] kv_lens, # i32[max_num_seqs] page_indices, # i32[max_num_seqs, pages_per_seq] cu_q_lens, # i32[max_num_seqs + 1] @@ -1014,8 +955,7 @@ def ragged_paged_attention( ): if mask_value is None: mask_value = DEFAULT_MASK_VALUE - validate_ragged_paged_attention_inputs(q, k_pages, v_pages, kv_lens, - page_indices, cu_q_lens, num_seqs) + if not use_kernel: return _ragged_paged_attention_nonkernel( q, diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py index f981445ac96..feb77690b5a 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py @@ -1,3 +1,16 @@ +# Copyright 2025 The JAX Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """TPU-Friendly Ragged Paged Attention kernel. This kernel offers a highly optimized implementation of ragged paged attention, @@ -15,9 +28,6 @@ DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) -# TODO(jevinjiang): importing kernel from pltpu ops directly. No need -# to keep duplicated implementations. - class MultiPageAsyncCopyDescriptor: """Descriptor for async copy of multiple K/V pages from HBM.""" @@ -27,7 +37,7 @@ def __init__( pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads_per_blk, head_dim] vmem_buf, # [num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] sem, - page_indices_ref, # i32[num_seqs, pages_per_seq] + page_indices_ref, # i32[max_num_seqs, pages_per_seq] offset, # [seq_idx, kv_pages_start] ): self._vmem_buf = vmem_buf @@ -60,30 +70,33 @@ def wait(self): def ref_ragged_paged_attention( queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - kv_lens: jax.Array, # i32[num_seqs] - page_indices: jax.Array, # i32[num_seqs, pages_per_seq] - cu_q_lens: jax.Array, # i32[num_seqs + 1] - num_seqs: int, + k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads * head_dim] + v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads * head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs: jax.Array, # i32[1], *, sm_scale: float = 1.0, mask_value: float = DEFAULT_MASK_VALUE, ): - _, _, num_kv_heads, head_dim = k_pages.shape - num_q_heads = queries.shape[1] + check_inputs_shapes(queries, k_pages, v_pages, kv_lens, page_indices, + cu_q_lens, num_seqs) + _, num_q_heads, head_dim = queries.shape + _, _, kv_hidden_size = k_pages.shape + num_kv_heads = kv_hidden_size // head_dim assert num_q_heads % num_kv_heads == 0 num_query_per_kv = num_q_heads // num_kv_heads outputs = [] - for i in range(num_seqs): + for i in range(num_seqs[0]): q_start = cu_q_lens[i] q_end = cu_q_lens[i + 1] q_len = q_end - q_start kv_len = kv_lens[i] indices = page_indices[i] q = queries[q_start:q_end] - k = k_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] - v = v_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + k = k_pages[indices, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + v = v_pages[indices, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] k = jnp.repeat(k, num_query_per_kv, axis=1) v = jnp.repeat(v, num_query_per_kv, axis=1) attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32) @@ -107,25 +120,26 @@ def validate_inputs_on_runtime( kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs, # i32 + num_seqs, # i32[1] ): - check_inputs_shapes(q, k_pages, 
v_pages, kv_lens, page_indices, cu_q_lens) + check_inputs_shapes(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, + num_seqs) max_num_batched_tokens = q.shape[0] page_size = k_pages.shape[1] max_num_seqs, pages_per_seq = page_indices.shape - if num_seqs > max_num_seqs: - raise ValueError(f"{num_seqs=} must be less or equal to {max_num_seqs=}") + if num_seqs[0] > max_num_seqs: + raise ValueError(f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}") max_kv_len = jnp.max(kv_lens) - min_pages_per_seq = ceil_div(max_kv_len, page_size) + min_pages_per_seq = cdiv(max_kv_len, page_size) if pages_per_seq < min_pages_per_seq: raise ValueError( f"{pages_per_seq=} must be greater or equal to" f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.") - if cu_q_lens[num_seqs] > max_num_batched_tokens: + if cu_q_lens[num_seqs[0]] > max_num_batched_tokens: raise ValueError( - f"Total q tokens {cu_q_lens[num_seqs]} must be less or equal to" + f"Total q tokens {cu_q_lens[num_seqs[0]]} must be less or equal to" f" {max_num_batched_tokens=}.") - for i in range(num_seqs): + for i in range(num_seqs[0]): q_len = cu_q_lens[i + 1] - cu_q_lens[i] kv_len = kv_lens[i] if q_len > kv_len: @@ -136,21 +150,29 @@ def validate_inputs_on_runtime( # Expect to run these checks during compile time. def check_inputs_shapes( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads * head_dim] + v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads * head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs, # i32[1] ): - max_num_batched_tokens, num_q_heads, head_dim = q.shape - _, _, num_kv_heads, head_dim_k = k_pages.shape - max_num_seqs, _ = page_indices.shape + _, num_q_heads, head_dim = q.shape + if head_dim != 128: + raise NotImplementedError(f"Only support head_dim=128, got {head_dim=}") if k_pages.shape != v_pages.shape: raise ValueError( - f"{k_pages.shape=} and {v_pages.shape=} must have the same shape.") - if head_dim_k != head_dim: - raise ValueError( - f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}.") + f"Expected {k_pages.shape=} to be equal to {v_pages.shape=}.") + _, page_size, kv_hidden_size = k_pages.shape + kv_packing = get_dtype_packing(k_pages.dtype) + if page_size % kv_packing != 0: + raise ValueError(f"Expected {page_size=} is divisible by {kv_packing=}") + if kv_hidden_size % head_dim != 0: + raise ValueError(f"Expected {kv_hidden_size=} is divisible by {head_dim=}.") + num_kv_heads = kv_hidden_size // head_dim + if num_q_heads % num_kv_heads != 0: + raise ValueError(f"Expected {num_q_heads=} is divisible by {num_kv_heads=}") + max_num_seqs, _ = page_indices.shape if kv_lens.shape != (max_num_seqs,): raise ValueError(f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where" " `max_num_seqs` is `page_indices.shape[0]`.") @@ -158,14 +180,14 @@ def check_inputs_shapes( raise ValueError( f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where" " `max_num_seqs` is `page_indices.shape[0]`.") + if num_seqs.shape != (1,): + raise ValueError(f"Expected {num_seqs.shape=} is (1,)") if (kv_lens.dtype != jnp.int32 or page_indices.dtype != jnp.int32 or cu_q_lens.dtype != jnp.int32): raise ValueError( - "The dtype of `kv_lens`, 
`page_indices`, and `cu_q_lens` must be" + "Expected the dtypes of `kv_lens`, `page_indices`, and `cu_q_lens` are" f" int32. Got {kv_lens.dtype=}, {page_indices.dtype=}," f" {cu_q_lens.dtype=}.") - if num_q_heads % num_kv_heads != 0: - raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}") def ragged_paged_attention_kernel( @@ -178,13 +200,13 @@ def ragged_paged_attention_kernel( num_seqs_ref, # Input q_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] - k_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] + k_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads * head_dim] + v_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads * head_dim] # Output o_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] # Scratch - k_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] - v_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + k_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk * head_dim] + v_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk * head_dim] sems, # [2, 2] l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] @@ -194,7 +216,8 @@ def ragged_paged_attention_kernel( ): num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape num_seqs = num_seqs_ref[0] - _, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, _ = k_bufs.shape + _, num_kv_pages_per_blk, page_size, buf_model_dim = k_bufs.shape + num_kv_heads_per_blk = buf_model_dim // head_dim num_kv_per_blk = num_kv_pages_per_blk * page_size num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk heads_blk_idx, q_blk_idx = ( @@ -213,7 +236,8 @@ def create_kv_async_copy_descriptors(heads_blk_idx, seq_idx, kv_blk_idx, heads_start = heads_blk_idx * num_kv_heads_per_blk async_copy_k = MultiPageAsyncCopyDescriptor( k_pages_hbm_ref.at[:, :, - pl.ds(heads_start, num_kv_heads_per_blk), :], + pl.ds(heads_start * head_dim, num_kv_heads_per_blk * + head_dim)], k_bufs.at[buf_idx], sems.at[buf_idx, 0], page_indices_ref, @@ -221,7 +245,8 @@ def create_kv_async_copy_descriptors(heads_blk_idx, seq_idx, kv_blk_idx, ) async_copy_v = MultiPageAsyncCopyDescriptor( v_pages_hbm_ref.at[:, :, - pl.ds(heads_start, num_kv_heads_per_blk), :], + pl.ds(heads_start * head_dim, num_kv_heads_per_blk * + head_dim)], v_bufs.at[buf_idx], sems.at[buf_idx, 1], page_indices_ref, @@ -229,25 +254,6 @@ def create_kv_async_copy_descriptors(heads_blk_idx, seq_idx, kv_blk_idx, ) return async_copy_k, async_copy_v - # TODO(jevinjiang): Add these to Mosaic: - # 1. Support arbitrary strided load/store for any dtype. - # 2. Support arbitrary strided load/store for any last dimension. 
- def strided_load_kv(ref, start, step): - if ref.dtype == jnp.float32: - return ref[start::step, :] - packing = get_dtype_packing(ref.dtype) - assert ref.dtype == jnp.bfloat16 - assert step % packing == 0 - b_start = start // packing - b_offset = start % packing - b_step = step // packing - b_ref = ref.bitcast(jnp.int32) - b = b_ref[b_start::b_step, :] - bw = 32 // packing - b = jnp.right_shift(b, bw * b_offset) - b = jnp.left_shift(b, bw * (packing - 1)) - return pltpu.bitcast(b, jnp.float32).astype(jnp.bfloat16) - def fold_on_2nd_minor(vec): assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32 assert len(vec.shape) >= 2 @@ -460,8 +466,8 @@ def prefetch_next_kv_blk(): cur_async_copy_k, cur_async_copy_v = create_kv_async_copy_descriptors( heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx) kv_to_load_shape = ( - num_kv_pages_per_blk * page_size * num_kv_heads_per_blk, - head_dim, + num_kv_pages_per_blk * page_size, + num_kv_heads_per_blk * head_dim, ) k_ref = cur_async_copy_k.wait().reshape(kv_to_load_shape) v_ref = cur_async_copy_v.wait().reshape(kv_to_load_shape) @@ -471,8 +477,9 @@ def prefetch_next_kv_blk(): # unaligned position! q = fold_on_2nd_minor(q_ref[:, q_head_idx:q_head_idx + num_q_heads_per_kv_head, :]) - k = strided_load_kv(k_ref, kv_head_idx, num_kv_heads_per_blk) - v = strided_load_kv(v_ref, kv_head_idx, num_kv_heads_per_blk) + k = k_ref[:, kv_head_idx * head_dim:(kv_head_idx + 1) * head_dim] + v = v_ref[:, kv_head_idx * head_dim:(kv_head_idx + 1) * head_dim] + # TODO(jevinjiang): resolve spill issue! flash_attention( q, k, @@ -487,7 +494,7 @@ def prefetch_next_kv_blk(): _, next_buf_idx = lax.while_loop( is_valid_kv_blk_in_cur_seq, compute_with_kv_blk_in_cur_seq, - (0, cur_buf_idx), + (0, cur_buf_idx), # (kv_blk_idx, buf_idx) ) next_seq_idx = lax.select(q_end <= q_len_end, cur_seq_idx + 1, cur_seq_idx) done = lax.select(q_end < q_len_end, done, 1) @@ -496,14 +503,14 @@ def prefetch_next_kv_blk(): _, seq_idx, buf_idx = lax.while_loop( is_cur_q_blk_needed, compute_with_cur_q_blk, - (0, init_seq_idx, init_buf_idx), + (0, init_seq_idx, init_buf_idx), # (done, seq_idx, buf_idx) ) # Reset seq_idx for next kv_heads_blk if run out of seqs! seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0) seq_buf_idx_ref[1] = buf_idx -def ceil_div(a, b): +def cdiv(a, b): assert b != 0 return (a + b - 1) // b @@ -520,31 +527,22 @@ def get_dtype_packing(dtype): raise ValueError(f"Not implemented: unsupported {dtype=}") -def get_min_heads_per_blk(num_q_heads, num_kv_heads, q_dtype, kv_dtype): - q_packing = get_dtype_packing(q_dtype) - kv_packing = get_dtype_packing(kv_dtype) +def get_min_q_heads_per_blk(num_q_heads, q_dtype, num_q_heads_per_kv_head): - def can_be_xla_fully_tiled(x, packing): - if x % packing != 0: - return False - x //= packing - return x in (1, 2, 4, 8) or x % 8 == 0 + def gcd(a, b): + while b: + a, b = b, a % b + return a - # TODO(jevinjiang): support unaligned number of heads! - if not can_be_xla_fully_tiled(num_kv_heads, kv_packing): - raise ValueError( - f"Not implemented: {num_kv_heads=} can not be XLA fully tiled.") - assert num_q_heads % num_kv_heads == 0 - ratio = num_q_heads // num_kv_heads - # TODO(jevinjiang): we can choose smaller tiling for packed type if large - # second minor tiling is not on. 
- max_kv_tiling = 8 * kv_packing - min_kv_heads = ( - max_kv_tiling if num_kv_heads % max_kv_tiling == 0 else num_kv_heads) - min_q_heads = min_kv_heads * ratio - if can_be_xla_fully_tiled(min_q_heads, q_packing): - return min_q_heads, min_kv_heads - return num_q_heads, num_kv_heads + def lcm(a, b): + return a * b // gcd(a, b) + + q_packing = get_dtype_packing(q_dtype) + max_q_tiling = 8 * q_packing + min_q_heads = lcm(max_q_tiling, num_q_heads_per_kv_head) + if num_q_heads % min_q_heads == 0: + return min_q_heads + return num_q_heads @functools.partial( @@ -559,13 +557,12 @@ def can_be_xla_fully_tiled(x, packing): ) def ragged_paged_attention( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - # TODO(jevinjiang): create a write_to_kv_cache kernel! - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads * head_dim] + v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads * head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs, # i32[1] + num_seqs: jax.Array, # i32[1] *, sm_scale: float = 1.0, mask_value: float = DEFAULT_MASK_VALUE, @@ -596,16 +593,19 @@ def ragged_paged_attention( Returns: The output of the attention. """ - check_inputs_shapes(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens) + check_inputs_shapes(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, + num_seqs) num_q, num_q_heads, head_dim = q.shape - _, page_size, num_kv_heads, _ = k_pages.shape + _, page_size, kv_hidden_size = k_pages.shape + num_kv_heads = kv_hidden_size // head_dim num_q_per_blk = num_queries_per_block num_kv_pages_per_blk = num_kv_pages_per_block num_q_heads_per_kv_head = num_q_heads // num_kv_heads - num_q_blks = ceil_div(num_q, num_q_per_blk) - num_q_heads_per_blk, num_kv_heads_per_blk = get_min_heads_per_blk( - num_q_heads, num_kv_heads, q.dtype, k_pages.dtype) + num_q_blks = cdiv(num_q, num_q_per_blk) + num_q_heads_per_blk = get_min_q_heads_per_blk(num_q_heads, q.dtype, + num_q_heads_per_kv_head) assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0 + num_kv_heads_per_blk = num_q_heads_per_blk // num_q_heads_per_kv_head num_heads_blks = num_q_heads // num_q_heads_per_blk grid = (num_heads_blks, num_q_blks) @@ -633,15 +633,14 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): 2, # For double buffering during DMA copies. num_kv_pages_per_blk, page_size, - num_kv_heads_per_blk, - head_dim, + num_kv_heads_per_blk * head_dim, ), k_pages.dtype, ) scratch_shapes = [ double_buf_scratch, # k_bufs double_buf_scratch, # v_bufs - pltpu.SemaphoreType.DMA((2, 2)), + pltpu.SemaphoreType.DMA((2, 2)), # [double_buffers, k_sem/v_sem] lm_scratch, # l_ref lm_scratch, # m_ref ] @@ -650,7 +649,8 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): page_indices, cu_q_lens, jnp.array((0, 0), jnp.int32), # seq_idx, buf_idx - num_seqs) + num_seqs, + ) kernel = pl.pallas_call( functools.partial( ragged_paged_attention_kernel,
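
Illustrative sketch (not taken from the patch): toy PyTorch code showing how the new KV-cache layout [total_num_pages, page_size, num_kv_heads * head_dim] relates to the previous [total_num_pages, page_size, num_kv_heads, head_dim] layout, and why head h can now be read as a contiguous column slice of the hidden dimension instead of via strided_load_kv. All shapes, tensor names, and page indices below are hypothetical.

import torch

# Toy sizes chosen only for illustration; head_dim=128 matches the kernel's requirement.
total_num_pages, page_size, num_kv_heads, head_dim = 8, 16, 2, 128

# Old layout: one axis per KV head.
k_pages_4d = torch.randn(total_num_pages, page_size, num_kv_heads, head_dim)

# New layout: heads folded into a single "hidden states" dimension.
k_pages_3d = k_pages_4d.reshape(total_num_pages, page_size, num_kv_heads * head_dim)

# Gathering the pages of one sequence and recovering per-head K/V, as the
# reference implementations do after this change.
indices = torch.tensor([3, 0, 5])  # hypothetical page indices for one sequence
k_new = k_pages_3d[indices].reshape(-1, num_kv_heads, head_dim)
k_old = k_pages_4d[indices].reshape(-1, num_kv_heads, head_dim)
assert torch.equal(k_new, k_old)

# Inside the kernel, head h becomes a contiguous slice of the last dimension,
# replacing the strided (bitcast-based) load over a separate head axis.
h = 1
assert torch.equal(
    k_pages_3d[:, :, h * head_dim:(h + 1) * head_dim],
    k_pages_4d[:, :, h, :])

Folding the heads into the hidden dimension keeps each head's data contiguous along the last axis, which is what allows the patch to drop strided_load_kv and the XLA tiling restriction on the number of KV heads.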