Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 9 additions & 18 deletions custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,14 @@ __global__ void multi_query_append_attention_kernel(
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16

const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
load_q_global_smem<GROUP_SIZE, num_frags_x, num_frags_y, HEAD_DIM, T>(
q_base_ptr,
&qo_smem,
q_base_seq_id_this_block,
q_end,
q_len,
q_ori_n_stride,
HEAD_DIM);
commit_group();
Expand Down Expand Up @@ -486,11 +484,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;

// When cudagraph capture prefill, may launch more gridDim.x
if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
return;
}

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
Expand Down Expand Up @@ -569,9 +562,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
tid % 16, tid / 16); // 16 * 16

const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里用一些边界case测试一下，offset确实不会超过div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)吗

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里用一些边界case测试一下，offset确实不会超过div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)吗

这里因为每个CTA 最多只读 num_rows_per_block 个Q head_dim,所以只需要检查不超过q_len即可


#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
Expand All @@ -583,7 +573,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
T>(q_base_ptr,
&qo_smem,
q_base_seq_id_this_block,
q_end,
q_len,
q_ori_n_stride,
HEAD_DIM);
commit_group();
Expand All @@ -593,9 +583,10 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
q_smem_inplace_multiply_sm_scale_multi_warps<num_frags_x, num_frags_y, T>(
&qo_smem, scale);

smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + (num_frags_x + NUM_WARP_KV * num_frags_z) * 16 * HEAD_DIM *
sizeof(T));
static_assert(num_rows_per_block == num_frags_x * 16);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

num_rows_per_block应该等于NUM_WARP_Q * num_frags_x * 16(tensor core的一个mma m维),这里因为原本NUM_WARP_Q等于1做了省略,assert的话可以加上

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NUM_WARP_Q == 1的assert 在函数开头加上了哈

static_assert(BLOCK_SIZE == NUM_WARP_KV * num_frags_z * 16);
smem_t k_smem(smem + num_rows_per_block * HEAD_DIM * sizeof(T)),
v_smem(smem + (num_rows_per_block + BLOCK_SIZE) * HEAD_DIM * sizeof(T));

const uint32_t num_iterations = div_up(
CAUSAL
Expand All @@ -605,13 +596,13 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE),
chunk_start)))
: chunk_len,
NUM_WARP_KV * num_frags_z * 16);
BLOCK_SIZE);
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(kv_len - q_len, chunk_start)))
: mask_offset ? 0
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
(BLOCK_SIZE);

uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
Expand Down Expand Up @@ -698,7 +689,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();

kv_idx_base += NUM_WARP_KV * num_frags_z * 16;
kv_idx_base += BLOCK_SIZE;
block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]);
if (block_id < 0) {
block_id = 0;
Expand Down
Loading