From e06ca33bee9cd1b82c7e40c950ecf347eb581a05 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Wed, 19 Nov 2025 20:10:21 +0800 Subject: [PATCH 01/10] add decoder_batch_ele_num --- .../get_block_shape_and_split_kv_block.cu | 27 +++++-------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index 3368eb6200a..c4bf6a9209a 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -307,6 +307,8 @@ void GetBlockShapeAndSplitKVBlock( int max_just_dec_len_without_system = max_len_cpu_ptr[7]; int max_kv_len_this_time = max_len_cpu_ptr[8]; + const uint32_t decoder_batch_ele_num = decoder_batch_ids.shape()[0]; + // decoder if (max_dec_len_this_time > 0) { const bool mla_backend = checkAttentionBackend(); @@ -343,25 +345,15 @@ void GetBlockShapeAndSplitKVBlock( decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false); const int chunk_size = decoder_chunk_size_cpu.data()[0]; - // NOTE: (changwenbin) When using auto_chunk, - // decode_max_tile_size must take into account the maximum case, where * - // 1024 can cover 128K. const uint32_t decoder_batch_shape = - // seq_lens_decoder.dims()[0] * 1024; - - const uint32_t decoder_max_tile_size_per_bs_q = - div_up((decoder_step_token_num * group_size), decoder_block_shape_q); - const uint32_t decoder_batch_shape = - bsz * 1024 * decoder_max_tile_size_per_bs_q; - PADDLE_ENFORCE_GPU_SUCCESS( cudaMemsetAsync(decoder_batch_ids.data(), 0, - decoder_batch_shape * sizeof(int32_t), + decoder_batch_ele_num * sizeof(int32_t), stream)); PADDLE_ENFORCE_GPU_SUCCESS( cudaMemsetAsync(decoder_tile_ids_per_batch.data(), 0, - decoder_batch_shape * sizeof(int32_t), + decoder_batch_ele_num * sizeof(int32_t), stream)); split_block_for_mla<<<1, 32, 0, stream>>>( @@ -374,22 +366,15 @@ void GetBlockShapeAndSplitKVBlock( chunk_size); } else { - // Note:(changwenbin)In order to adapt to cudagraph, the maximum value - // should be taken here - const uint32_t decoder_max_tile_size_per_bs_q = - div_up((decoder_step_token_num * group_size), decoder_block_shape_q); - const uint32_t decoder_batch_shape = - bsz * 1024 * decoder_max_tile_size_per_bs_q; - PADDLE_ENFORCE_GPU_SUCCESS( cudaMemsetAsync(decoder_batch_ids.data(), 0, - decoder_batch_shape * sizeof(int32_t), + decoder_batch_ele_num * sizeof(int32_t), stream)); PADDLE_ENFORCE_GPU_SUCCESS( cudaMemsetAsync(decoder_tile_ids_per_batch.data(), 0, - decoder_batch_shape * sizeof(int32_t), + decoder_batch_ele_num * sizeof(int32_t), stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( decoder_num_blocks_device.data(), 0, sizeof(int32_t), stream)); From 472d52588f3b4d7d5c9d4f4f579011b908919410 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Wed, 19 Nov 2025 22:15:40 +0800 Subject: [PATCH 02/10] remove else --- .../append_attn/get_block_shape_and_split_kv_block.cu | 7 ------- 1 file changed, 7 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index c4bf6a9209a..3937fe201e8 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -398,13 +398,6 @@ void GetBlockShapeAndSplitKVBlock( PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); } - } else { - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( - decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( - decoder_num_blocks_device.data(), 0, sizeof(int32_t), stream)); - decoder_num_blocks_cpu.copy_( - decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); } // encoder From a294368d1445fa2e9896e87d15ef84e5f0dcb52f Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Wed, 19 Nov 2025 23:39:28 +0800 Subject: [PATCH 03/10] clean model_executor/layers/attention/append_attn_backend.py --- .../layers/attention/append_attn_backend.py | 42 +++++++++---------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 3e3a56aa424..ee50e55c434 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -54,9 +54,6 @@ class AppendAttentionMetadata(AttentionMetadata): _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 max_partition_size: int = 32768 - block_tables: Optional[paddle.Tensor] = None - rotary_embs: Optional[paddle.Tensor] = None - attn_mask: Optional[paddle.Tensor] = None _fuse_kernel_compute_dtype: str = "bf16" # pd_disaggregation @@ -75,18 +72,21 @@ def allocate_launch_related_buffer( block_size, ): # Initialize AttentionBackend buffers - group_size = np.ceil(num_heads / kv_num_heads) + assert num_heads % kv_num_heads == 0 + assert max_model_len % block_size == 0 + assert max_model_len % encoder_block_shape_q == 0 + group_size = num_heads // kv_num_heads # NOTE: (changwenbin) When using auto_chunk, # decode_max_tile_size must take into account the maximum case, where *1024 can cover 128K. decode_max_tile_size = ( - 1024 * max_batch_size * np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q) + 1024 * max_batch_size * (int)(np.ceil(decoder_step_token_num * group_size / decoder_block_shape_q)) ) - encode_max_tile_size = max_batch_size * np.ceil((max_model_len * group_size) / encoder_block_shape_q) - kv_max_tile_size = max_batch_size * np.ceil(max_model_len / block_size) + encode_max_tile_size = max_batch_size * (max_model_len * group_size // encoder_block_shape_q) + kv_max_tile_size = max_batch_size * (max_model_len // block_size) res = {} - res["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") - res["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + res["decoder_batch_ids"] = paddle.full([decode_max_tile_size], 0, dtype="int32") + res["decoder_tile_ids_per_batch"] = paddle.full([decode_max_tile_size], 0, dtype="int32") res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory() # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor, # adapted to cudagraph. @@ -94,12 +94,12 @@ def allocate_launch_related_buffer( res["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32") res["max_len_tensor_cpu"] = paddle.full([9], 0, dtype="int32").cpu() - res["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") - res["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") + res["encoder_batch_ids"] = paddle.full([encode_max_tile_size], 0, dtype="int32") + res["encoder_tile_ids_per_batch"] = paddle.full([encode_max_tile_size], 0, dtype="int32") res["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() - res["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") - res["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") + res["kv_batch_ids"] = paddle.full([kv_max_tile_size], 0, dtype="int32") + res["kv_tile_ids_per_batch"] = paddle.full([kv_max_tile_size], 0, dtype="int32") res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() return res @@ -175,10 +175,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): metadata._fuse_kernel_compute_dtype = "fp16" elif metadata._dtype == "float32": metadata._fuse_kernel_compute_dtype = "fp32" - metadata.block_tables = forward_meta.block_tables - metadata.rotary_embs = forward_meta.rotary_embs - metadata.attn_mask = forward_meta.attn_mask - metadata.pre_caches_length = forward_meta.pre_caches_length # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers @@ -330,7 +326,7 @@ def forward_mixed( forward_meta.seq_lens_this_time, forward_meta.batch_id_per_token, forward_meta.cu_seqlens_q, - metadata.block_tables, + forward_meta.block_tables, forward_meta.encoder_batch_ids, forward_meta.encoder_tile_ids_per_batch, forward_meta.encoder_num_blocks_x_cpu, @@ -342,8 +338,8 @@ def forward_mixed( forward_meta.decoder_num_blocks_cpu, forward_meta.max_len_tensor_cpu, res, - metadata.rotary_embs, - metadata.attn_mask, + forward_meta.rotary_embs, + forward_meta.attn_mask, layer.qkv_bias, layer.qkv_scale, cache_k_scales, @@ -387,7 +383,7 @@ def forward_mixed( forward_meta.seq_lens_this_time, forward_meta.batch_id_per_token, forward_meta.cu_seqlens_q, - metadata.block_tables, + forward_meta.block_tables, forward_meta.encoder_batch_ids, forward_meta.encoder_tile_ids_per_batch, forward_meta.encoder_num_blocks_x_cpu, @@ -398,8 +394,8 @@ def forward_mixed( forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, forward_meta.max_len_tensor_cpu, - metadata.rotary_embs, - metadata.attn_mask, + forward_meta.rotary_embs, + forward_meta.attn_mask, layer.qkv_bias, layer.qkv_scale, cache_k_scales, From f8ebf65db878c7edf6491205d5969d8b1e53028b Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 20 Nov 2025 07:44:28 +0800 Subject: [PATCH 04/10] remove decoder_step_num --- .../get_block_shape_and_split_kv_block.cu | 12 ++++-------- custom_ops/gpu_ops/cpp_extensions.cc | 3 +-- .../layers/attention/append_attn_backend.py | 3 +-- .../layers/attention/flash_attn_backend.py | 1 - .../layers/attention/mla_attention_backend.py | 1 - .../ops/get_block_shape_and_split_kv_block.py | 4 +--- tests/layers/test_append_attention.py | 1 - tests/layers/test_append_attention_with_output.py | 1 - tests/operators/test_tree_mask.py | 1 - 9 files changed, 7 insertions(+), 20 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index 3937fe201e8..c22a6a188f5 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -273,8 +273,7 @@ void GetBlockShapeAndSplitKVBlock( const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, - const int block_size, - const int decoder_step_token_num) { + const int block_size) { auto stream = seq_lens_encoder.stream(); int bsz = seq_lens_this_time.shape()[0]; @@ -464,8 +463,7 @@ std::vector> GetBlockShapeAndSplitKVBlockInferShape( const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, - const int block_size, - const int decoder_step_token_num) { + const int block_size) { return {}; } @@ -476,8 +474,7 @@ std::vector GetBlockShapeAndSplitKVBlockInferDtype( const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, - const int block_size, - const int decoder_step_token_num) { + const int block_size) { return {}; } @@ -505,8 +502,7 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) .Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int", "group_size: int", - "block_size: int", - "decoder_step_token_num: int"}) + "block_size: int"}) .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock)) .SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype)); diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 6ecc1ed1451..0e6853d9be5 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -381,8 +381,7 @@ void GetBlockShapeAndSplitKVBlock( const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, - const int block_size, - const int decoder_step_token_num); + const int block_size); std::vector GetPaddingOffset(const paddle::Tensor& input_ids, const paddle::Tensor& token_num, diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index ee50e55c434..dbcaa7505af 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -278,8 +278,7 @@ def forward_mixed( self.encoder_block_shape_q, self.decoder_block_shape_q, self.group_size, - self.block_size, - self.speculate_max_draft_token_num + 1, + self.block_size ) if self.use_output: diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index bce361eb5dd..31d6d748885 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -213,7 +213,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.decoder_block_shape_q, self.group_size, self.block_size, - self.speculate_max_draft_token_num + 1, ) ( diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 54e72379eab..035109f2b6f 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -204,7 +204,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.decoder_block_shape_q, self.group_size, self.block_size, - self.speculate_max_draft_token_num + 1, ) # MLA diff --git a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py index 1cd5f4f142b..721699ba5f0 100644 --- a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py +++ b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py @@ -44,7 +44,6 @@ def get_block_shape_and_split_kv_block( decoder_block_shape_q: int, group_size: int, block_size: int, - decoder_step_token_num: int, ): """ get_block_shape_and_split_kv_block @@ -69,8 +68,7 @@ def get_block_shape_and_split_kv_block( encoder_block_shape_q, decoder_block_shape_q, group_size, - block_size, - decoder_step_token_num, + block_size ) else: diff --git a/tests/layers/test_append_attention.py b/tests/layers/test_append_attention.py index 4cc00858de8..01ad4bb932b 100644 --- a/tests/layers/test_append_attention.py +++ b/tests/layers/test_append_attention.py @@ -628,7 +628,6 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask 12, (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head, self.blocksize, - speculate_max_draft_token_num + 1, ) if self.use_dynamic_quant: cache_quant_type = "block_wise_fp8" diff --git a/tests/layers/test_append_attention_with_output.py b/tests/layers/test_append_attention_with_output.py index 5f08c737179..6c15de17ccc 100644 --- a/tests/layers/test_append_attention_with_output.py +++ b/tests/layers/test_append_attention_with_output.py @@ -479,7 +479,6 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask 12, (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head, self.blocksize, - speculate_max_draft_token_num + 1, ) # Warm up diff --git a/tests/operators/test_tree_mask.py b/tests/operators/test_tree_mask.py index 1cfbaaf7a56..57a62044814 100644 --- a/tests/operators/test_tree_mask.py +++ b/tests/operators/test_tree_mask.py @@ -254,7 +254,6 @@ def run_append_c16_attention( decoder_block_shape_q, self.num_q_head // self.num_kv_head, self.block_size, - decoder_step_token_num, ) s_time = 0 for i in range(self.run_time + self.warm_up): From f647d9cf34f6e7939992ee128dd8b5994b2584a0 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 20 Nov 2025 07:45:32 +0800 Subject: [PATCH 05/10] remove decoder_step_num --- .../layers/backends/metax/attention/mla_attn_metax_backend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py index ff1bce8bdc1..40a5bdc3f23 100644 --- a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py @@ -179,7 +179,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.decoder_block_shape_q, self.group_size, self.block_size, - self.speculate_max_draft_token_num + 1, ) # MLA From fe89e3b282d928f708a6fd32b5678240d778745c Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 20 Nov 2025 07:48:54 +0800 Subject: [PATCH 06/10] remove decoder_step_num --- .../model_executor/layers/attention/append_attn_backend.py | 2 +- .../layers/attention/ops/get_block_shape_and_split_kv_block.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index dbcaa7505af..000baab2c22 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -278,7 +278,7 @@ def forward_mixed( self.encoder_block_shape_q, self.decoder_block_shape_q, self.group_size, - self.block_size + self.block_size, ) if self.use_output: diff --git a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py index 721699ba5f0..a97cf16664f 100644 --- a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py +++ b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py @@ -68,7 +68,7 @@ def get_block_shape_and_split_kv_block( encoder_block_shape_q, decoder_block_shape_q, group_size, - block_size + block_size, ) else: From 4cb47739560ba549bfc55dede69c86729685516d Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 20 Nov 2025 08:18:19 +0800 Subject: [PATCH 07/10] remove decoder_step_num --- tests/layers/test_attention_layer.py | 66 ++++++++++++++-------------- 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/tests/layers/test_attention_layer.py b/tests/layers/test_attention_layer.py index 32d579f74b6..91bd43eb6e3 100644 --- a/tests/layers/test_attention_layer.py +++ b/tests/layers/test_attention_layer.py @@ -121,10 +121,10 @@ def create_model_config_json(self) -> str: "dtype": "bfloat16", "hidden_size": 4096, "max_position_embeddings": 131072, - "max_model_len": 5500, + "max_model_len": 36 * 1024 + 1024, "num_attention_heads": 32, "num_key_value_heads": 4, - "num_hidden_layers": 5, + "num_hidden_layers": 57, } model_dir = tempfile.mkdtemp(prefix="tmp_model_config_") config_path = os.path.join(model_dir, "config.json") @@ -223,7 +223,7 @@ def create_forward_meta( max_model_len=fd_config.model_config.max_model_len, encoder_block_shape_q=64, decoder_block_shape_q=16, - decoder_step_token_num=1, + decoder_step_token_num=fd_config.speculative_config.num_speculative_tokens + 1, num_heads=fd_config.model_config.num_attention_heads, kv_num_heads=fd_config.model_config.num_key_value_heads, block_size=fd_config.cache_config.block_size, @@ -294,29 +294,30 @@ def create_forward_meta( def test_decode_performance_with_prefill(self): # Test parameters test_steps = 100 - prefill_batch_size = 1 - prefill_seq_len = 4096 use_dynamic_quant = True act_tensor_dtype = paddle.bfloat16 - prefill_hidden_states = paddle.randn( - [prefill_batch_size * prefill_seq_len, self.fd_config.model_config.hidden_size], - dtype=act_tensor_dtype, - ) + # prefill_batch_size = 1 + # prefill_seq_len = 4096 - forward_meta = self.create_forward_meta( - batch_size=prefill_batch_size, - seq_len=prefill_seq_len, - mode=ForwardMode.EXTEND, - fd_config=self.fd_config, - attn_backend=self.attn_backend, - use_dynamic_quant=use_dynamic_quant, - ) + # prefill_hidden_states = paddle.randn( + # [prefill_batch_size * prefill_seq_len, self.fd_config.model_config.hidden_size], + # dtype=act_tensor_dtype, + # ) - self.attn_backend.init_attention_metadata(forward_meta) - self.attn_forward(forward_meta, prefill_hidden_states) + # forward_meta = self.create_forward_meta( + # batch_size=prefill_batch_size, + # seq_len=prefill_seq_len, + # mode=ForwardMode.EXTEND, + # fd_config=self.fd_config, + # attn_backend=self.attn_backend, + # use_dynamic_quant=use_dynamic_quant, + # ) + + # self.attn_backend.init_attention_metadata(forward_meta) + # self.attn_forward(forward_meta, prefill_hidden_states) - paddle.device.synchronize() + # paddle.device.synchronize() # import paddle.profiler as profiler # p = profiler.Profiler( @@ -326,18 +327,18 @@ def test_decode_performance_with_prefill(self): # p.start() # p.step() - start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)] - end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)] - for i in range(test_steps): - start_events[i].record() + # start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)] + # end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)] + # for i in range(test_steps): + # start_events[i].record() - self.attn_forward(forward_meta, prefill_hidden_states) + # self.attn_forward(forward_meta, prefill_hidden_states) - end_events[i].record() - paddle.device.synchronize() + # end_events[i].record() + # paddle.device.synchronize() - times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:] - print(times[-5:]) + # times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:] + # print(times[-5:]) # p.stop() @@ -349,14 +350,14 @@ def test_decode_performance_with_prefill(self): # p.start() # p.step() - for decode_batch_size in [10, 20, 40, 60, 80, 100, 128]: + for decode_batch_size in [32, 16, 8, 4, 2]: decode_hidden_states = paddle.randn( [decode_batch_size, self.fd_config.model_config.hidden_size], dtype=act_tensor_dtype ) forward_meta = self.create_forward_meta( batch_size=decode_batch_size, - seq_len=5000, + seq_len=36 * 1024, mode=ForwardMode.DECODE, fd_config=self.fd_config, attn_backend=self.attn_backend, @@ -383,7 +384,6 @@ def test_decode_performance_with_prefill(self): start_events[i].record() attn_cuda_graphs.replay() - # self.attn_forward(forward_meta, decode_hidden_states) end_events[i].record() paddle.device.synchronize() @@ -391,6 +391,8 @@ def test_decode_performance_with_prefill(self): times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:] print(times[-5:]) + del forward_meta + # p.stop() From 08e6a101a6e169b861ce60d05fc12072f06e1f53 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 20 Nov 2025 16:47:00 +0800 Subject: [PATCH 08/10] do not modify allocate_launch_related_buffer --- .../gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index 4c8854d8465..e84f8281648 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -305,8 +305,6 @@ void GetBlockShapeAndSplitKVBlock( const uint32_t decoder_batch_ele_num = decoder_batch_ids.shape()[0]; - const uint32_t decoder_batch_ele_num = decoder_batch_ids.shape()[0]; - // decoder if (max_dec_len_this_time > 0) { const bool mla_backend = checkAttentionBackend(); From af543cb2bbb78d68a246a04cd2b4fbdebdb110f8 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 20 Nov 2025 16:47:39 +0800 Subject: [PATCH 09/10] do not modify allocate_launch_related_buffer --- tests/entrypoints/openai/test_run_batch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/entrypoints/openai/test_run_batch.py b/tests/entrypoints/openai/test_run_batch.py index fc6803452a6..4cd82f49165 100644 --- a/tests/entrypoints/openai/test_run_batch.py +++ b/tests/entrypoints/openai/test_run_batch.py @@ -172,7 +172,7 @@ async def test_initialize_engine_client(self, mock_engine_client): mock_args = Mock() mock_args.model = "test-model" mock_args.tokenizer = "test-tokenizer" - mock_args.max_model_len = 1000 + mock_args.max_model_len = 1024 mock_args.tensor_parallel_size = 1 mock_args.engine_worker_queue_port = [8000] mock_args.local_data_parallel_id = 0 @@ -202,7 +202,7 @@ async def test_initialize_engine_client(self, mock_engine_client): def test_create_serving_handlers(self, mock_chat_handler, mock_model_handler): """测试创建服务处理器""" mock_args = Mock() - mock_args.max_model_len = 1000 + mock_args.max_model_len = 1024 mock_args.ips = "127.0.0.1" mock_args.max_waiting_time = 60 mock_args.enable_mm_output = False @@ -1286,7 +1286,7 @@ def run_fastdeploy_command(self, input_content, port=None): "--quantization", "wint4", "--max-model-len", - "4192", + "5120", "--max-num-seqs", "64", "--load-choices", From 74cb09553c553a48b0a92a201deab09cfc7293c3 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 20 Nov 2025 16:50:23 +0800 Subject: [PATCH 10/10] do not modify allocate_launch_related_buffer --- .../get_block_shape_and_split_kv_block.cu | 79 ++++++++----------- 1 file changed, 35 insertions(+), 44 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index e84f8281648..2b5c1fbc7d0 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -311,16 +311,17 @@ void GetBlockShapeAndSplitKVBlock( if (mla_backend && group_size <= 64) { const int set_chunk_size = get_mla_dec_chunk_size(bsz); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + CUDA_CHECK(cudaMemsetAsync( decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + CUDA_CHECK(cudaMemsetAsync( decoder_num_blocks_device.data(), 0, sizeof(int32_t), stream)); int device; - cudaGetDevice(&device); + CUDA_CHECK(cudaGetDevice(&device)); int sm_cout; - cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device); + CUDA_CHECK(cudaDeviceGetAttribute( + &sm_cout, cudaDevAttrMultiProcessorCount, device)); constexpr int config_size = 12; // search space for chunk size:[64, 128, 256, ... 131072] @@ -341,16 +342,14 @@ void GetBlockShapeAndSplitKVBlock( decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false); const int chunk_size = decoder_chunk_size_cpu.data()[0]; - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemsetAsync(decoder_batch_ids.data(), - 0, - decoder_batch_ele_num * sizeof(int32_t), - stream)); - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemsetAsync(decoder_tile_ids_per_batch.data(), - 0, - decoder_batch_ele_num * sizeof(int32_t), - stream)); + CUDA_CHECK(cudaMemsetAsync(decoder_batch_ids.data(), + 0, + decoder_batch_ele_num * sizeof(int32_t), + stream)); + CUDA_CHECK(cudaMemsetAsync(decoder_tile_ids_per_batch.data(), + 0, + decoder_batch_ele_num * sizeof(int32_t), + stream)); split_block_for_mla<<<1, 32, 0, stream>>>( seq_lens_this_time.data(), @@ -362,17 +361,15 @@ void GetBlockShapeAndSplitKVBlock( chunk_size); } else { - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemsetAsync(decoder_batch_ids.data(), - 0, - decoder_batch_ele_num * sizeof(int32_t), - stream)); - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemsetAsync(decoder_tile_ids_per_batch.data(), - 0, - decoder_batch_ele_num * sizeof(int32_t), - stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + CUDA_CHECK(cudaMemsetAsync(decoder_batch_ids.data(), + 0, + decoder_batch_ele_num * sizeof(int32_t), + stream)); + CUDA_CHECK(cudaMemsetAsync(decoder_tile_ids_per_batch.data(), + 0, + decoder_batch_ele_num * sizeof(int32_t), + stream)); + CUDA_CHECK(cudaMemsetAsync( decoder_num_blocks_device.data(), 0, sizeof(int32_t), stream)); split_q_block<<<1, 32, 0, stream>>>( @@ -391,8 +388,6 @@ void GetBlockShapeAndSplitKVBlock( #endif decoder_num_blocks_cpu.copy_( decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( - decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); } } @@ -401,19 +396,17 @@ void GetBlockShapeAndSplitKVBlock( const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size); const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv; - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + CUDA_CHECK(cudaMemsetAsync( kv_batch_ids.data(), 0, kv_batch_shape * sizeof(int32_t), stream)); - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemsetAsync(kv_tile_ids_per_batch.data(), - 0, - kv_batch_shape * sizeof(int32_t), - stream)); + CUDA_CHECK(cudaMemsetAsync(kv_tile_ids_per_batch.data(), + 0, + kv_batch_shape * sizeof(int32_t), + stream)); auto kv_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); split_kv_block<<<1, 32, 0, seq_lens_encoder.stream()>>>( seq_lens_decoder.data(), - // sequence_lengths->data(), seq_lens_encoder.data(), kv_batch_ids.data(), kv_tile_ids_per_batch.data(), @@ -428,16 +421,14 @@ void GetBlockShapeAndSplitKVBlock( const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q); const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q; - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemsetAsync(encoder_batch_ids.data(), - 0, - encoder_batch_shape * sizeof(int32_t), - stream)); - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemsetAsync(encoder_tile_ids_per_batch.data(), - 0, - encoder_batch_shape * sizeof(int32_t), - stream)); + CUDA_CHECK(cudaMemsetAsync(encoder_batch_ids.data(), + 0, + encoder_batch_shape * sizeof(int32_t), + stream)); + CUDA_CHECK(cudaMemsetAsync(encoder_tile_ids_per_batch.data(), + 0, + encoder_batch_shape * sizeof(int32_t), + stream)); auto encoder_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data(),