
Commit a294368

clean model_executor/layers/attention/append_attn_backend.py

1 parent 472d525

File tree

1 file changed (+19, -23 lines)

fastdeploy/model_executor/layers/attention/append_attn_backend.py

Lines changed: 19 additions & 23 deletions
```diff
@@ -54,9 +54,6 @@ class AppendAttentionMetadata(AttentionMetadata):
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
-    block_tables: Optional[paddle.Tensor] = None
-    rotary_embs: Optional[paddle.Tensor] = None
-    attn_mask: Optional[paddle.Tensor] = None
     _fuse_kernel_compute_dtype: str = "bf16"
 
     # pd_disaggregation
@@ -75,31 +72,34 @@ def allocate_launch_related_buffer(
     block_size,
 ):
     # Initialize AttentionBackend buffers
-    group_size = np.ceil(num_heads / kv_num_heads)
+    assert num_heads % kv_num_heads == 0
+    assert max_model_len % block_size == 0
+    assert max_model_len % encoder_block_shape_q == 0
+    group_size = num_heads // kv_num_heads
 
     # NOTE: (changwenbin) When using auto_chunk,
     # decode_max_tile_size must take into account the maximum case, where *1024 can cover 128K.
     decode_max_tile_size = (
-        1024 * max_batch_size * np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q)
+        1024 * max_batch_size * (int)(np.ceil(decoder_step_token_num * group_size / decoder_block_shape_q))
     )
-    encode_max_tile_size = max_batch_size * np.ceil((max_model_len * group_size) / encoder_block_shape_q)
-    kv_max_tile_size = max_batch_size * np.ceil(max_model_len / block_size)
+    encode_max_tile_size = max_batch_size * (max_model_len * group_size // encoder_block_shape_q)
+    kv_max_tile_size = max_batch_size * (max_model_len // block_size)
     res = {}
-    res["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
-    res["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
+    res["decoder_batch_ids"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
+    res["decoder_tile_ids_per_batch"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
     res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
     # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor,
     # adapted to cudagraph.
     res["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32")
     res["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32")
     res["max_len_tensor_cpu"] = paddle.full([9], 0, dtype="int32").cpu()
 
-    res["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
-    res["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
+    res["encoder_batch_ids"] = paddle.full([encode_max_tile_size], 0, dtype="int32")
+    res["encoder_tile_ids_per_batch"] = paddle.full([encode_max_tile_size], 0, dtype="int32")
     res["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
 
-    res["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
-    res["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
+    res["kv_batch_ids"] = paddle.full([kv_max_tile_size], 0, dtype="int32")
+    res["kv_tile_ids_per_batch"] = paddle.full([kv_max_tile_size], 0, dtype="int32")
     res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
 
     return res
```
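The sizing change in the second hunk is worth unpacking. `np.ceil` returns a NumPy float (`np.float64`), which is why the old code had to wrap every buffer size in `int(...)` before handing it to `paddle.full`. Once divisibility is asserted up front, plain floor division produces exact Python ints and the casts disappear. A minimal sketch of the before/after arithmetic, using standalone NumPy and illustrative head counts that are not taken from the commit:

```python
import numpy as np

# Illustrative values only; the real ones come from the model config.
num_heads, kv_num_heads = 32, 8
max_model_len, block_size = 8192, 64
max_batch_size = 16

# Before: np.ceil yields np.float64, so every downstream shape
# needed an explicit int(...) cast.
group_size_old = np.ceil(num_heads / kv_num_heads)
print(type(group_size_old))              # <class 'numpy.float64'>
kv_tiles_old = max_batch_size * np.ceil(max_model_len / block_size)
print(kv_tiles_old, type(kv_tiles_old))  # 2048.0 <class 'numpy.float64'>

# After: assert exact divisibility once, then floor-divide,
# which keeps every size a plain Python int.
assert num_heads % kv_num_heads == 0
assert max_model_len % block_size == 0
group_size_new = num_heads // kv_num_heads        # 4, an int
kv_tiles_new = max_batch_size * (max_model_len // block_size)
print(kv_tiles_new, type(kv_tiles_new))  # 2048 <class 'int'>
```

With the asserts in place, `//` and `np.ceil` agree exactly for the encoder and KV sizes, so the switch changes only the dtype of the result, not its value.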
```diff
@@ -175,10 +175,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             metadata._fuse_kernel_compute_dtype = "fp16"
         elif metadata._dtype == "float32":
             metadata._fuse_kernel_compute_dtype = "fp32"
-        metadata.block_tables = forward_meta.block_tables
-        metadata.rotary_embs = forward_meta.rotary_embs
-        metadata.attn_mask = forward_meta.attn_mask
-        metadata.pre_caches_length = forward_meta.pre_caches_length
 
         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
```
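This hunk and the four `forward_mixed` hunks below all follow from the first one: the mirrored fields (`block_tables`, `rotary_embs`, `attn_mask`, plus `pre_caches_length`) are deleted from `AppendAttentionMetadata`, so `init_attention_metadata` no longer copies them over and `forward_mixed` reads them straight from `forward_meta`. A minimal sketch of the pattern, with hypothetical stand-in classes rather than the real FastDeploy types:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ForwardMeta:            # stand-in for fastdeploy's ForwardMeta
    block_tables: Optional[object] = None
    rotary_embs: Optional[object] = None
    attn_mask: Optional[object] = None

@dataclass
class AttnMetadata:           # stand-in for AppendAttentionMetadata
    # Only backend-private state remains; the duplicated
    # block_tables/rotary_embs/attn_mask fields are gone.
    _dtype: str = "bfloat16"

def forward_mixed(metadata: AttnMetadata, forward_meta: ForwardMeta):
    # After the change, kernel arguments come from forward_meta directly,
    # so each tensor has one owner instead of two aliased copies.
    return (forward_meta.block_tables,
            forward_meta.rotary_embs,
            forward_meta.attn_mask)

meta = ForwardMeta(block_tables="bt", rotary_embs="re", attn_mask="am")
print(forward_mixed(AttnMetadata(), meta))  # ('bt', 're', 'am')
```

With a single owner for these tensors, a rebuilt `ForwardMeta` cannot leave stale aliases behind on the backend metadata, which is presumably the point of the cleanup.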
```diff
@@ -330,7 +326,7 @@ def forward_mixed(
             forward_meta.seq_lens_this_time,
             forward_meta.batch_id_per_token,
             forward_meta.cu_seqlens_q,
-            metadata.block_tables,
+            forward_meta.block_tables,
             forward_meta.encoder_batch_ids,
             forward_meta.encoder_tile_ids_per_batch,
             forward_meta.encoder_num_blocks_x_cpu,
@@ -342,8 +338,8 @@ def forward_mixed(
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.max_len_tensor_cpu,
             res,
-            metadata.rotary_embs,
-            metadata.attn_mask,
+            forward_meta.rotary_embs,
+            forward_meta.attn_mask,
             layer.qkv_bias,
             layer.qkv_scale,
             cache_k_scales,
@@ -387,7 +383,7 @@ def forward_mixed(
             forward_meta.seq_lens_this_time,
             forward_meta.batch_id_per_token,
             forward_meta.cu_seqlens_q,
-            metadata.block_tables,
+            forward_meta.block_tables,
             forward_meta.encoder_batch_ids,
             forward_meta.encoder_tile_ids_per_batch,
             forward_meta.encoder_num_blocks_x_cpu,
@@ -398,8 +394,8 @@ def forward_mixed(
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.max_len_tensor_cpu,
-            metadata.rotary_embs,
-            metadata.attn_mask,
+            forward_meta.rotary_embs,
+            forward_meta.attn_mask,
             layer.qkv_bias,
             layer.qkv_scale,
             cache_k_scales,
```
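One loose end is the `*1024` factor in `decode_max_tile_size`, documented only by the inline NOTE ("can cover 128K"). A back-of-envelope reading of that claim, offered as an assumption about the intent rather than anything stated in the commit:

```python
# If auto_chunk splits a decode context of up to 128K tokens into at
# most 1024 chunks (the hard-coded multiplier), each chunk must cover:
max_context = 128 * 1024          # 131072 tokens
max_chunks = 1024
print(max_context // max_chunks)  # 128 tokens per chunk, the implied floor
```

Under that reading, 1024 chunk slots per batch entry suffice for a 128K context as long as the chunk size never drops below 128 tokens at that length.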
