@@ -54,9 +54,6 @@ class AppendAttentionMetadata(AttentionMetadata):
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
-    block_tables: Optional[paddle.Tensor] = None
-    rotary_embs: Optional[paddle.Tensor] = None
-    attn_mask: Optional[paddle.Tensor] = None
     _fuse_kernel_compute_dtype: str = "bf16"
 
     # pd_disaggregation
@@ -75,31 +72,34 @@ def allocate_launch_related_buffer(
     block_size,
 ):
     # Initialize AttentionBackend buffers
-    group_size = np.ceil(num_heads / kv_num_heads)
+    assert num_heads % kv_num_heads == 0
+    assert max_model_len % block_size == 0
+    assert max_model_len % encoder_block_shape_q == 0
+    group_size = num_heads // kv_num_heads
 
     # NOTE: (changwenbin) When using auto_chunk, decode_max_tile_size must
     # cover the maximum case; the *1024 factor accommodates up to 128K context.
     decode_max_tile_size = (
-        1024 * max_batch_size * np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q)
+        1024 * max_batch_size * int(np.ceil(decoder_step_token_num * group_size / decoder_block_shape_q))
     )
-    encode_max_tile_size = max_batch_size * np.ceil((max_model_len * group_size) / encoder_block_shape_q)
-    kv_max_tile_size = max_batch_size * np.ceil(max_model_len / block_size)
+    encode_max_tile_size = max_batch_size * (max_model_len * group_size // encoder_block_shape_q)
+    kv_max_tile_size = max_batch_size * (max_model_len // block_size)
     res = {}
-    res["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
-    res["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
+    res["decoder_batch_ids"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
+    res["decoder_tile_ids_per_batch"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
     res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
     # NOTE: (changwenbin) The MLA kernel only needs decoder_num_blocks_device
     # in place of the GPU tensor, adapted for cudagraph.
     res["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32")
     res["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32")
     res["max_len_tensor_cpu"] = paddle.full([9], 0, dtype="int32").cpu()
 
-    res["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
-    res["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
+    res["encoder_batch_ids"] = paddle.full([encode_max_tile_size], 0, dtype="int32")
+    res["encoder_tile_ids_per_batch"] = paddle.full([encode_max_tile_size], 0, dtype="int32")
     res["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
 
-    res["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
-    res["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
+    res["kv_batch_ids"] = paddle.full([kv_max_tile_size], 0, dtype="int32")
+    res["kv_tile_ids_per_batch"] = paddle.full([kv_max_tile_size], 0, dtype="int32")
     res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
 
     return res
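With the new divisibility asserts in place, the encoder and KV sizes become exact integer divisions; only the decoder size still needs a ceil, since decoder_step_token_num * group_size need not be a multiple of decoder_block_shape_q. A quick sanity check of the arithmetic, with illustrative values that are not taken from this PR:

```python
import numpy as np

# Illustrative values only -- real callers pass these in from the
# model and scheduler config.
max_batch_size = 8
num_heads, kv_num_heads = 32, 4      # GQA: 32 query heads share 4 KV heads
max_model_len = 8192
encoder_block_shape_q = 64
decoder_block_shape_q = 16
decoder_step_token_num = 1
block_size = 64

# The new asserts guarantee the integer divisions below are exact.
assert num_heads % kv_num_heads == 0
assert max_model_len % block_size == 0
assert max_model_len % encoder_block_shape_q == 0

group_size = num_heads // kv_num_heads  # 8

# Decode keeps a ceil; the *1024 headroom covers auto_chunk at up to 128K.
decode_max_tile_size = 1024 * max_batch_size * int(
    np.ceil(decoder_step_token_num * group_size / decoder_block_shape_q)
)
encode_max_tile_size = max_batch_size * (max_model_len * group_size // encoder_block_shape_q)
kv_max_tile_size = max_batch_size * (max_model_len // block_size)

print(decode_max_tile_size, encode_max_tile_size, kv_max_tile_size)
# -> 8192 8192 1024
```

Because every size is now a plain Python int, the `int(...)` wrappers previously needed around the `paddle.full` shapes can be dropped, which is exactly what the rest of the hunk does.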
@@ -175,10 +175,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             metadata._fuse_kernel_compute_dtype = "fp16"
         elif metadata._dtype == "float32":
             metadata._fuse_kernel_compute_dtype = "fp32"
-        metadata.block_tables = forward_meta.block_tables
-        metadata.rotary_embs = forward_meta.rotary_embs
-        metadata.attn_mask = forward_meta.attn_mask
-        metadata.pre_caches_length = forward_meta.pre_caches_length
 
         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
@@ -330,7 +326,7 @@ def forward_mixed(
             forward_meta.seq_lens_this_time,
             forward_meta.batch_id_per_token,
             forward_meta.cu_seqlens_q,
-            metadata.block_tables,
+            forward_meta.block_tables,
             forward_meta.encoder_batch_ids,
             forward_meta.encoder_tile_ids_per_batch,
             forward_meta.encoder_num_blocks_x_cpu,
@@ -342,8 +338,8 @@ def forward_mixed(
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.max_len_tensor_cpu,
             res,
-            metadata.rotary_embs,
-            metadata.attn_mask,
+            forward_meta.rotary_embs,
+            forward_meta.attn_mask,
             layer.qkv_bias,
             layer.qkv_scale,
             cache_k_scales,
@@ -387,7 +383,7 @@ def forward_mixed(
             forward_meta.seq_lens_this_time,
             forward_meta.batch_id_per_token,
             forward_meta.cu_seqlens_q,
-            metadata.block_tables,
+            forward_meta.block_tables,
             forward_meta.encoder_batch_ids,
             forward_meta.encoder_tile_ids_per_batch,
             forward_meta.encoder_num_blocks_x_cpu,
@@ -398,8 +394,8 @@ def forward_mixed(
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.max_len_tensor_cpu,
-            metadata.rotary_embs,
-            metadata.attn_mask,
+            forward_meta.rotary_embs,
+            forward_meta.attn_mask,
             layer.qkv_bias,
             layer.qkv_scale,
             cache_k_scales,
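The remaining hunks are a single refactor: block_tables, rotary_embs, and attn_mask used to be copied from ForwardMeta onto AppendAttentionMetadata in init_attention_metadata and read back in forward_mixed; after this change the kernel call reads them from forward_meta directly, so the dataclass no longer mirrors per-step state. A minimal sketch of the pattern, with simplified types and stand-in bodies (not the actual FastDeploy signatures):

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class ForwardMeta:
    # Rebuilt by the model runner every step; the single source of truth.
    # The real fields are paddle.Tensor; Any keeps this sketch dependency-free.
    block_tables: Optional[Any] = None
    rotary_embs: Optional[Any] = None
    attn_mask: Optional[Any] = None

@dataclass
class AppendAttentionMetadata:
    # After the change, only backend-private settings live here; the three
    # per-step tensors above no longer have duplicate fields.
    max_partition_size: int = 32768

def forward_mixed(forward_meta: ForwardMeta, metadata: AppendAttentionMetadata):
    # Before: init_attention_metadata copied the tensors onto `metadata` each
    # step and the kernel call read metadata.block_tables, etc.
    # After: read straight from forward_meta, so nothing can go stale.
    return (forward_meta.block_tables, forward_meta.rotary_embs, forward_meta.attn_mask)
```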