diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index d8c94dc5446..cf283a617d2 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -157,8 +157,6 @@ __global__ void multi_query_append_attention_kernel( uint32_t q_smem_offset_r = smem_t::get_permuted_offset( wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); #endif @@ -166,7 +164,7 @@ __global__ void multi_query_append_attention_kernel( q_base_ptr, &qo_smem, q_base_seq_id_this_block, - q_end, + q_len, q_ori_n_stride, HEAD_DIM); commit_group(); @@ -486,11 +484,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const uint32_t num_rows_per_block = num_frags_x * 16; const int *block_table_now = block_table + batch_id * max_block_num_per_seq; - // When cudagraph capture prefill, may launch more gridDim.x - if (btid >= static_cast(num_blocks_x_cpu)) { - return; - } - const uint32_t q_len = seq_lens[batch_id]; if (q_len <= 0) { return; @@ -569,9 +562,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel( uint32_t q_smem_offset_r = smem_t::get_permuted_offset( tid % 16, tid / 16); // 16 * 16 - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); - #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); #endif @@ -583,7 +573,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel( T>(q_base_ptr, &qo_smem, q_base_seq_id_this_block, - q_end, + q_len, q_ori_n_stride, HEAD_DIM); commit_group(); @@ -593,9 +583,10 @@ __global__ void multi_query_append_attention_warp1_4_kernel( q_smem_inplace_multiply_sm_scale_multi_warps( &qo_smem, scale); - smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), - v_smem(smem + (num_frags_x + NUM_WARP_KV * num_frags_z) * 16 * HEAD_DIM * - sizeof(T)); + static_assert(num_rows_per_block == num_frags_x * 16); + static_assert(BLOCK_SIZE == NUM_WARP_KV * num_frags_z * 16); + smem_t k_smem(smem + num_rows_per_block * HEAD_DIM * sizeof(T)), + v_smem(smem + (num_rows_per_block + BLOCK_SIZE) * HEAD_DIM * sizeof(T)); const uint32_t num_iterations = div_up( CAUSAL @@ -605,13 +596,13 @@ __global__ void multi_query_append_attention_warp1_4_kernel( div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), chunk_start))) : chunk_len, - NUM_WARP_KV * num_frags_z * 16); + BLOCK_SIZE); const uint32_t mask_check_iteration = (CAUSAL ? (min(chunk_len, sub_if_greater_or_zero(kv_len - q_len, chunk_start))) : mask_offset ? 0 : chunk_len) / - (NUM_WARP_KV * num_frags_z * 16); + (BLOCK_SIZE); uint32_t k_smem_offset_r = smem_t::get_permuted_offset( wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); @@ -698,7 +689,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel( s_frag, o_frag, m_frag, d_frag); __syncthreads(); - kv_idx_base += NUM_WARP_KV * num_frags_z * 16; + kv_idx_base += BLOCK_SIZE; block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); if (block_id < 0) { block_id = 0;