@@ -80,26 +80,26 @@ def allocate_launch_related_buffer(
8080 # NOTE: (changwenbin) When using auto_chunk,
8181 # decode_max_tile_size must take into account the maximum case, where *1024 can cover 128K.
8282 decode_max_tile_size = (
83- 1024 * max_batch_size * np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q)
83+ 1024 * max_batch_size * (int)(np.ceil(decoder_step_token_num * group_size / decoder_block_shape_q))
8484 )
8585 encode_max_tile_size = max_batch_size * (max_model_len * group_size // encoder_block_shape_q)
8686 kv_max_tile_size = max_batch_size * max_model_len // block_size
8787 res = {}
88- res["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
89- res["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
88+ res["decoder_batch_ids"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
89+ res["decoder_tile_ids_per_batch"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
9090 res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
9191 # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor,
9292 # adapted to cudagraph.
9393 res["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32")
9494 res["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32")
9595 res["max_len_tensor_cpu"] = paddle.full([5], 0, dtype="int32").cpu()
9696
97- res["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
98- res["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
97+ res["encoder_batch_ids"] = paddle.full([encode_max_tile_size], 0, dtype="int32")
98+ res["encoder_tile_ids_per_batch"] = paddle.full([encode_max_tile_size], 0, dtype="int32")
9999 res["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
100100
101- res["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
102- res["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
101+ res["kv_batch_ids"] = paddle.full([kv_max_tile_size], 0, dtype="int32")
102+ res["kv_tile_ids_per_batch"] = paddle.full([kv_max_tile_size], 0, dtype="int32")
103103 res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
104104
105105 return res
0 commit comments