Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ void GetBlockShapeAndSplitKVBlock(
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
int max_kv_len_this_time = max_len_cpu_ptr[8];

const uint32_t decoder_batch_ele_num = decoder_batch_ids.shape()[0];

// decoder
if (max_dec_len_this_time > 0) {
const bool mla_backend = checkAttentionBackend();
Expand Down Expand Up @@ -343,25 +345,15 @@ void GetBlockShapeAndSplitKVBlock(
decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];

// NOTE: (changwenbin) When using auto_chunk,
// decode_max_tile_size must take into account the maximum case, where *
// 1024 can cover 128K. const uint32_t decoder_batch_shape =
// seq_lens_decoder.dims()[0] * 1024;

const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape =
bsz * 1024 * decoder_max_tile_size_per_bs_q;

PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_batch_ids.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
decoder_batch_ele_num * sizeof(int32_t),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
decoder_batch_ele_num * sizeof(int32_t),
stream));

split_block_for_mla<<<1, 32, 0, stream>>>(
Expand All @@ -374,22 +366,15 @@ void GetBlockShapeAndSplitKVBlock(
chunk_size);

} else {
// Note:(changwenbin)In order to adapt to cudagraph, the maximum value
// should be taken here
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape =
bsz * 1024 * decoder_max_tile_size_per_bs_q;

PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_batch_ids.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
decoder_batch_ele_num * sizeof(int32_t),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
decoder_batch_ele_num * sizeof(int32_t),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
Expand All @@ -413,13 +398,6 @@ void GetBlockShapeAndSplitKVBlock(
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
}

// encoder
Expand Down
Loading