Commit 6fa3410

[Others] get_block_shape_and_split_kv_block clean code (#5123)
1 parent af715db commit 6fa3410

12 files changed: +364 additions, -355 deletions

custom_ops/gpu_ops/append_attention.cu

Lines changed: 308 additions & 255 deletions
Large diffs are not rendered by default.

custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu

Lines changed: 12 additions & 41 deletions
@@ -79,7 +79,7 @@ __global__ void GetMaxLenKernel(const int *seq_lens_decoder,
       max_lens[2] = total_max_len_decoder;
       max_lens[3] = total;
       max_lens[4] = total_just_dec;
-      max_lens[8] = total_max_len_kv;
+      max_lens[5] = total_max_len_kv;
     }
   }

@@ -273,8 +273,7 @@ void GetBlockShapeAndSplitKVBlock(
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
     const int group_size,
-    const int block_size,
-    const int decoder_step_token_num) {
+    const int block_size) {
   auto stream = seq_lens_encoder.stream();
   int bsz = seq_lens_this_time.shape()[0];

@@ -302,10 +301,9 @@ void GetBlockShapeAndSplitKVBlock(
   int max_dec_len_this_time = max_len_cpu_ptr[2];
   int max_enc_dec_len_this_time = max_len_cpu_ptr[3];
   int max_just_dec_len_this_time = max_len_cpu_ptr[4];
-  int max_just_dec_merged_len_this_time = max_len_cpu_ptr[5];
-  int max_system_len = max_len_cpu_ptr[6];
-  int max_just_dec_len_without_system = max_len_cpu_ptr[7];
-  int max_kv_len_this_time = max_len_cpu_ptr[8];
+  int max_kv_len_this_time = max_len_cpu_ptr[5];
+
+  const uint32_t decoder_batch_ele_num = decoder_batch_ids.shape()[0];

   // decoder
   if (max_dec_len_this_time > 0) {

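Note on the remapping above: three max_lens slots that nothing reads any more are dropped, so the kv max length moves from index 8 down to index 5, and every downstream read of max_len_tensor_cpu[8] (see the MLA backend diffs below) shifts to index 5 in lockstep. A minimal sketch of the slot layout before and after, in Python for illustration only; the names are the local variables in this file, and slot 0 is omitted because this diff does not show it:

# Sketch, not repo code: max_len slot meanings inferred from this commit's reads.
OLD = {
    1: "max_enc_len_this_time",
    2: "max_dec_len_this_time",
    3: "max_enc_dec_len_this_time",
    4: "max_just_dec_len_this_time",
    5: "max_just_dec_merged_len_this_time",   # dropped: unused
    6: "max_system_len",                      # dropped: unused
    7: "max_just_dec_len_without_system",     # dropped: unused
    8: "max_kv_len_this_time",
}
NEW = {i: OLD[i] for i in range(1, 5)}
NEW[5] = OLD[8]  # max_kv_len_this_time shifts from slot 8 to slot 5
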
@@ -343,25 +341,15 @@ void GetBlockShapeAndSplitKVBlock(
         decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
     const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];

-    // NOTE: (changwenbin) When using auto_chunk,
-    // decode_max_tile_size must take into account the maximum case, where *
-    // 1024 can cover 128K. const uint32_t decoder_batch_shape =
-    // seq_lens_decoder.dims()[0] * 1024;
-
-    const uint32_t decoder_max_tile_size_per_bs_q =
-        div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
-    const uint32_t decoder_batch_shape =
-        bsz * 1024 * decoder_max_tile_size_per_bs_q;
-
     PADDLE_ENFORCE_GPU_SUCCESS(
         cudaMemsetAsync(decoder_batch_ids.data<int>(),
                         0,
-                        decoder_batch_shape * sizeof(int32_t),
+                        decoder_batch_ele_num * sizeof(int32_t),
                         stream));
     PADDLE_ENFORCE_GPU_SUCCESS(
         cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
                         0,
-                        decoder_batch_shape * sizeof(int32_t),
+                        decoder_batch_ele_num * sizeof(int32_t),
                         stream));

     split_block_for_mla<<<1, 32, 0, stream>>>(
@@ -374,22 +362,15 @@ void GetBlockShapeAndSplitKVBlock(
         chunk_size);

   } else {
-    // Note:(changwenbin)In order to adapt to cudagraph, the maximum value
-    // should be taken here
-    const uint32_t decoder_max_tile_size_per_bs_q =
-        div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
-    const uint32_t decoder_batch_shape =
-        bsz * 1024 * decoder_max_tile_size_per_bs_q;
-
     PADDLE_ENFORCE_GPU_SUCCESS(
         cudaMemsetAsync(decoder_batch_ids.data<int>(),
                         0,
-                        decoder_batch_shape * sizeof(int32_t),
+                        decoder_batch_ele_num * sizeof(int32_t),
                         stream));
     PADDLE_ENFORCE_GPU_SUCCESS(
         cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
                         0,
-                        decoder_batch_shape * sizeof(int32_t),
+                        decoder_batch_ele_num * sizeof(int32_t),
                         stream));
     PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
         decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));

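The two hunks above change the memset contract: previously the kernel recomputed a worst-case decoder_batch_shape from decoder_step_token_num so that CUDA graph capture covered the maximum case (per the deleted NOTE, the * 1024 headroom covers 128K contexts); now it simply clears decoder_batch_ids.shape()[0] elements, so worst-case sizing must happen wherever the buffer is allocated. A hedged sketch of that caller-side allocation: the helper name is hypothetical, the formula is the one deleted above, and in this repo the real allocation lives in allocate_launch_related_buffer in append_attn_backend.py:

import paddle

def div_up(a: int, b: int) -> int:
    return (a + b - 1) // b

def alloc_decoder_launch_buffers(bsz: int, decoder_step_token_num: int,
                                 group_size: int, decoder_block_shape_q: int):
    # Worst-case tile count per batch slot, mirroring the deleted C++ code.
    max_tile_size_per_bs_q = div_up(decoder_step_token_num * group_size,
                                    decoder_block_shape_q)
    num_elems = bsz * 1024 * max_tile_size_per_bs_q  # headroom for CUDA graph replay
    return {
        "decoder_batch_ids": paddle.full([num_elems], 0, dtype="int32"),
        "decoder_tile_ids_per_batch": paddle.full([num_elems], 0, dtype="int32"),
    }
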
@@ -413,13 +394,6 @@ void GetBlockShapeAndSplitKVBlock(
       PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
           decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
     }
-  } else {
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
-        decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
-        decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
-    decoder_num_blocks_cpu.copy_(
-        decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
   }

   // encoder
@@ -486,8 +460,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
     const int group_size,
-    const int block_size,
-    const int decoder_step_token_num) {
+    const int block_size) {
   return {};
 }

@@ -498,8 +471,7 @@ std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
     const int group_size,
-    const int block_size,
-    const int decoder_step_token_num) {
+    const int block_size) {
   return {};
 }

@@ -527,8 +499,7 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
     .Attrs({"encoder_block_shape_q: int",
             "decoder_block_shape_q: int",
             "group_size: int",
-            "block_size: int",
-            "decoder_step_token_num: int"})
+            "block_size: int"})
     .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
     .SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 1 addition & 2 deletions
@@ -381,8 +381,7 @@ void GetBlockShapeAndSplitKVBlock(
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
     const int group_size,
-    const int block_size,
-    const int decoder_step_token_num);
+    const int block_size);

 std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
                                              const paddle::Tensor& token_num,

fastdeploy/model_executor/layers/attention/append_attn_backend.py

Lines changed: 7 additions & 15 deletions
@@ -54,9 +54,6 @@ class AppendAttentionMetadata(AttentionMetadata):
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
-    block_tables: Optional[paddle.Tensor] = None
-    rotary_embs: Optional[paddle.Tensor] = None
-    attn_mask: Optional[paddle.Tensor] = None
     _fuse_kernel_compute_dtype: str = "bf16"

     # pd_disaggregation
@@ -101,7 +98,6 @@ def allocate_launch_related_buffer(
     res["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
     res["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
     res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
-
     return res

@@ -175,10 +171,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             metadata._fuse_kernel_compute_dtype = "fp16"
         elif metadata._dtype == "float32":
             metadata._fuse_kernel_compute_dtype = "fp32"
-        metadata.block_tables = forward_meta.block_tables
-        metadata.rotary_embs = forward_meta.rotary_embs
-        metadata.attn_mask = forward_meta.attn_mask
-        metadata.pre_caches_length = forward_meta.pre_caches_length

         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
@@ -263,6 +255,7 @@ def forward_mixed(
         cache_v_scales = getattr(layer, "cache_v_scale", None)

         if layer.layer_id == 0:
+            # print(forward_meta.seq_lens_this_time)
             get_block_shape_and_split_kv_block(
                 forward_meta.seq_lens_encoder,
                 forward_meta.seq_lens_decoder,

@@ -283,7 +276,6 @@ def forward_mixed(
                 self.decoder_block_shape_q,
                 self.group_size,
                 self.block_size,
-                self.speculate_max_draft_token_num + 1,
             )

         if self.use_output:
@@ -330,7 +322,7 @@ def forward_mixed(
                 forward_meta.seq_lens_this_time,
                 forward_meta.batch_id_per_token,
                 forward_meta.cu_seqlens_q,
-                metadata.block_tables,
+                forward_meta.block_tables,
                 forward_meta.encoder_batch_ids,
                 forward_meta.encoder_tile_ids_per_batch,
                 forward_meta.encoder_num_blocks_x_cpu,

@@ -342,8 +334,8 @@ def forward_mixed(
                 forward_meta.decoder_num_blocks_cpu,
                 forward_meta.max_len_tensor_cpu,
                 res,
-                metadata.rotary_embs,
-                metadata.attn_mask,
+                forward_meta.rotary_embs,
+                forward_meta.attn_mask,
                 layer.qkv_bias,
                 layer.qkv_scale,
                 cache_k_scales,
@@ -387,7 +379,7 @@ def forward_mixed(
                 forward_meta.seq_lens_this_time,
                 forward_meta.batch_id_per_token,
                 forward_meta.cu_seqlens_q,
-                metadata.block_tables,
+                forward_meta.block_tables,
                 forward_meta.encoder_batch_ids,
                 forward_meta.encoder_tile_ids_per_batch,
                 forward_meta.encoder_num_blocks_x_cpu,

@@ -398,8 +390,8 @@ def forward_mixed(
                 forward_meta.decoder_tile_ids_per_batch,
                 forward_meta.decoder_num_blocks_cpu,
                 forward_meta.max_len_tensor_cpu,
-                metadata.rotary_embs,
-                metadata.attn_mask,
+                forward_meta.rotary_embs,
+                forward_meta.attn_mask,
                 layer.qkv_bias,
                 layer.qkv_scale,
                 cache_k_scales,

fastdeploy/model_executor/layers/attention/flash_attn_backend.py

Lines changed: 0 additions & 1 deletion
@@ -213,7 +213,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             self.decoder_block_shape_q,
             self.group_size,
             self.block_size,
-            self.speculate_max_draft_token_num + 1,
         )

         (

fastdeploy/model_executor/layers/attention/mla_attention_backend.py

Lines changed: 1 addition & 2 deletions
@@ -204,13 +204,12 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             self.decoder_block_shape_q,
             self.group_size,
             self.block_size,
-            self.speculate_max_draft_token_num + 1,
         )

         # MLA
         metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
         metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
-        metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8]
+        metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[5]

         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers

fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py

Lines changed: 0 additions & 2 deletions
@@ -44,7 +44,6 @@ def get_block_shape_and_split_kv_block(
     decoder_block_shape_q: int,
     group_size: int,
     block_size: int,
-    decoder_step_token_num: int,
 ):
     """
     get_block_shape_and_split_kv_block
@@ -70,7 +69,6 @@ def get_block_shape_and_split_kv_block(
             decoder_block_shape_q,
             group_size,
             block_size,
-            decoder_step_token_num,
         )

     else:

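For anyone updating call sites across this commit, a sketch of the trimmed wrapper signature; the tensor parameters are unchanged and abbreviated here with `*tensor_inputs`, so this is illustrative rather than the file's exact def:

def get_block_shape_and_split_kv_block(
    *tensor_inputs,            # seq_lens_encoder, seq_lens_decoder, ... (unchanged)
    encoder_block_shape_q: int,
    decoder_block_shape_q: int,
    group_size: int,
    block_size: int,
    # decoder_step_token_num: int  <- removed; callers drop their trailing
    # `speculate_max_draft_token_num + 1` argument, as the backend and test
    # diffs in this commit show.
):
    ...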

fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py

Lines changed: 1 addition & 2 deletions
@@ -179,13 +179,12 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             self.decoder_block_shape_q,
             self.group_size,
             self.block_size,
-            self.speculate_max_draft_token_num + 1,
         )

         # MLA
         metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1].item()
         metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
-        metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8]
+        metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[5]

         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers

tests/layers/test_append_attention.py

Lines changed: 0 additions & 1 deletion
@@ -628,7 +628,6 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask
             12,
             (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
             self.blocksize,
-            speculate_max_draft_token_num + 1,
         )
         if self.use_dynamic_quant:
             cache_quant_type = "block_wise_fp8"

tests/layers/test_append_attention_with_output.py

Lines changed: 0 additions & 1 deletion
@@ -479,7 +479,6 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask
             12,
             (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
             self.blocksize,
-            speculate_max_draft_token_num + 1,
         )

         # Warm up
