Merged
Changes from all commits (25 commits)
4377d11
first commit
zhoutianzi666 Nov 19, 2025
a818eaf
first commit
zhoutianzi666 Nov 19, 2025
ccc3054
first commit
zhoutianzi666 Nov 19, 2025
aa1af1c
first commit
zhoutianzi666 Nov 19, 2025
982341f
first commit
zhoutianzi666 Nov 19, 2025
271566e
first commit
zhoutianzi666 Nov 19, 2025
3bee31e
first commit
zhoutianzi666 Nov 19, 2025
2a18524
first commit
zhoutianzi666 Nov 19, 2025
111925c
Merge remote-tracking branch 'origin/develop' into remove_code
zhoutianzi666 Nov 19, 2025
7526200
first commit
zhoutianzi666 Nov 19, 2025
386bcda
first commit
zhoutianzi666 Nov 19, 2025
16fcd6f
first commit
zhoutianzi666 Nov 19, 2025
5bddd3e
first commit
zhoutianzi666 Nov 19, 2025
8561cee
first commit
zhoutianzi666 Nov 19, 2025
02d067a
first commit
zhoutianzi666 Nov 19, 2025
a7945c9
Merge remote-tracking branch 'origin/develop' into remove_code
zhoutianzi666 Nov 19, 2025
37ac164
first commit
zhoutianzi666 Nov 19, 2025
b76ab14
first commit
zhoutianzi666 Nov 19, 2025
11e9dd4
Merge remote-tracking branch 'origin/develop' into remove_code
zhoutianzi666 Nov 19, 2025
bd2e9eb
Merge remote-tracking branch 'origin/develop' into remove_code
zhoutianzi666 Nov 19, 2025
b2966b7
Merge branch 'develop' into remove_code
zhoutianzi666 Nov 20, 2025
24a0d58
Merge branch 'develop' into remove_code
EmmonsCurse Nov 20, 2025
bcf830d
do not modify allocate_launch_related_buffer
zhoutianzi666 Nov 20, 2025
7a42c27
Merge remote-tracking branch 'origin/develop' into remove_code
zhoutianzi666 Nov 20, 2025
f2635ca
Merge remote-tracking branch 'myfd/remove_code' into remove_code
zhoutianzi666 Nov 20, 2025
563 changes: 308 additions & 255 deletions custom_ops/gpu_ops/append_attention.cu

Large diffs are not rendered by default.

@@ -79,7 +79,7 @@ __global__ void GetMaxLenKernel(const int *seq_lens_decoder,
max_lens[2] = total_max_len_decoder;
max_lens[3] = total;
max_lens[4] = total_just_dec;
max_lens[8] = total_max_len_kv;
max_lens[5] = total_max_len_kv;
}
}
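For reference, a hedged Python-side sketch of the max-len slot layout implied by this hunk: slots 1-5 follow from the kernel writes and the readers further down in this diff, slot 0 is an assumption, and max_kv_len_this_time moves from slot 8 to slot 5 now that the intermediate slots are gone.

# Illustrative only: slot layout of max_lens / max_len_tensor_cpu after this PR.
# Slot 0 is assumed (not shown in this diff); slots 1-5 come from the kernel
# writes and the C++/Python readers elsewhere in this change.
MAX_LEN_SLOTS = {
    1: "max_enc_len_this_time",
    2: "max_dec_len_this_time",
    3: "max_enc_dec_len_this_time",
    4: "max_just_dec_len_this_time",
    5: "max_kv_len_this_time",  # previously slot 8; slots 5-7 are removed
}

def read_max_kv_len(max_len_tensor_cpu) -> int:
    # Backends now read slot 5 instead of slot 8.
    return int(max_len_tensor_cpu[5])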

@@ -273,8 +273,7 @@ void GetBlockShapeAndSplitKVBlock(
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num) {
const int block_size) {
auto stream = seq_lens_encoder.stream();
int bsz = seq_lens_this_time.shape()[0];

@@ -302,10 +301,9 @@
int max_dec_len_this_time = max_len_cpu_ptr[2];
int max_enc_dec_len_this_time = max_len_cpu_ptr[3];
int max_just_dec_len_this_time = max_len_cpu_ptr[4];
int max_just_dec_merged_len_this_time = max_len_cpu_ptr[5];
int max_system_len = max_len_cpu_ptr[6];
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
int max_kv_len_this_time = max_len_cpu_ptr[8];
int max_kv_len_this_time = max_len_cpu_ptr[5];

const uint32_t decoder_batch_ele_num = decoder_batch_ids.shape()[0];

// decoder
if (max_dec_len_this_time > 0) {
@@ -343,25 +341,15 @@
decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];

// NOTE: (changwenbin) When using auto_chunk,
// decode_max_tile_size must take into account the maximum case, where *
// 1024 can cover 128K. const uint32_t decoder_batch_shape =
// seq_lens_decoder.dims()[0] * 1024;

const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape =
bsz * 1024 * decoder_max_tile_size_per_bs_q;

PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_batch_ids.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
decoder_batch_ele_num * sizeof(int32_t),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
decoder_batch_ele_num * sizeof(int32_t),
stream));

split_block_for_mla<<<1, 32, 0, stream>>>(
@@ -374,22 +362,15 @@
chunk_size);

} else {
// Note:(changwenbin)In order to adapt to cudagraph, the maximum value
// should be taken here
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape =
bsz * 1024 * decoder_max_tile_size_per_bs_q;

PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_batch_ids.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
decoder_batch_ele_num * sizeof(int32_t),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
decoder_batch_ele_num * sizeof(int32_t),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
@@ -413,13 +394,6 @@
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
}

// encoder
@@ -486,8 +460,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num) {
const int block_size) {
return {};
}

@@ -498,8 +471,7 @@ std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num) {
const int block_size) {
return {};
}

@@ -527,8 +499,7 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
.Attrs({"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"group_size: int",
"block_size: int",
"decoder_step_token_num: int"})
"block_size: int"})
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
.SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));
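As a caller-side illustration of the new buffer contract: with decoder_step_token_num removed, the op no longer recomputes a decoder_batch_shape and instead zeroes exactly decoder_batch_ids.shape()[0] elements, so the Python side fixes the size once at allocation time. A minimal sketch under that assumption (helper name and the tiles-per-batch bound are illustrative, not the actual allocate_launch_related_buffer code):

import paddle

def allocate_decoder_launch_buffers(max_batch_size: int, max_tiles_per_batch: int) -> dict:
    # The caller picks an upper bound once (max_tiles_per_batch is an assumed
    # value large enough for the longest supported decode); the custom op then
    # memsets decoder_batch_ids.shape()[0] entries on every call.
    num_elems = max_batch_size * max_tiles_per_batch
    return {
        "decoder_batch_ids": paddle.full([num_elems], 0, dtype="int32"),
        "decoder_tile_ids_per_batch": paddle.full([num_elems], 0, dtype="int32"),
    }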
3 changes: 1 addition & 2 deletions custom_ops/gpu_ops/cpp_extensions.cc
@@ -381,8 +381,7 @@ void GetBlockShapeAndSplitKVBlock(
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num);
const int block_size);

std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
const paddle::Tensor& token_num,
22 changes: 7 additions & 15 deletions fastdeploy/model_executor/layers/attention/append_attn_backend.py
@@ -54,9 +54,6 @@ class AppendAttentionMetadata(AttentionMetadata):
_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
_fuse_kernel_compute_dtype: str = "bf16"

# pd_disaggregation
@@ -101,7 +98,6 @@ def allocate_launch_related_buffer(
res["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
res["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()

return res


@@ -175,10 +171,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
metadata._fuse_kernel_compute_dtype = "fp16"
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
metadata.block_tables = forward_meta.block_tables
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask
metadata.pre_caches_length = forward_meta.pre_caches_length

# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
@@ -263,6 +255,7 @@ def forward_mixed(
cache_v_scales = getattr(layer, "cache_v_scale", None)

if layer.layer_id == 0:
# print(forward_meta.seq_lens_this_time)
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
Expand All @@ -283,7 +276,6 @@ def forward_mixed(
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)

if self.use_output:
@@ -330,7 +322,7 @@ def forward_mixed(
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
forward_meta.block_tables,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
@@ -342,8 +334,8 @@
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
res,
metadata.rotary_embs,
metadata.attn_mask,
forward_meta.rotary_embs,
forward_meta.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
cache_k_scales,
@@ -387,7 +379,7 @@ def forward_mixed(
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
forward_meta.block_tables,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
@@ -398,8 +390,8 @@
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.rotary_embs,
metadata.attn_mask,
forward_meta.rotary_embs,
forward_meta.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
cache_k_scales,
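With the fields above removed, the metadata keeps only kernel-configuration state, and the per-step tensors (block_tables, rotary_embs, attn_mask) are taken directly from forward_meta at the call sites. A condensed, illustrative sketch of the trimmed structure (field names and defaults copied from this diff; the class name is shortened here and the pd_disaggregation fields are omitted):

from dataclasses import dataclass
import paddle

@dataclass
class AppendAttentionMetadataSketch:
    # Per-request tensors are no longer cached on the metadata; callers use
    # forward_meta.block_tables / rotary_embs / attn_mask directly.
    _dtype: paddle.dtype = paddle.bfloat16
    encoder_max_partition_size: int = 32768
    max_partition_size: int = 32768
    _fuse_kernel_compute_dtype: str = "bf16"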
@@ -213,7 +213,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)

(
@@ -204,13 +204,12 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)

# MLA
metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8]
metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[5]

# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
@@ -44,7 +44,6 @@ def get_block_shape_and_split_kv_block(
decoder_block_shape_q: int,
group_size: int,
block_size: int,
decoder_step_token_num: int,
):
"""
get_block_shape_and_split_kv_block
@@ -70,7 +69,6 @@
decoder_block_shape_q,
group_size,
block_size,
decoder_step_token_num,
)

else:
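A small, hypothetical helper to show how call sites line up with the shortened wrapper signature: the scalar attributes now end at block_size, and the decoder_step_token_num argument (passed as speculate_max_draft_token_num + 1 in the backends above and the tests below) is dropped. Default values here are placeholders, not the configured ones.

def split_kv_scalar_attrs(encoder_block_shape_q: int = 64,
                          decoder_block_shape_q: int = 16,
                          group_size: int = 8,
                          block_size: int = 64) -> tuple:
    # Hypothetical helper bundling the scalar attributes the op now expects;
    # decoder_step_token_num is intentionally absent.
    return (encoder_block_shape_q, decoder_block_shape_q, group_size, block_size)

# Sketch of a call site, with the unchanged tensor arguments abbreviated:
#   get_block_shape_and_split_kv_block(*launch_buffers, *split_kv_scalar_attrs())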
@@ -179,13 +179,12 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)

# MLA
metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1].item()
metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8]
metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[5]

# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
1 change: 0 additions & 1 deletion tests/layers/test_append_attention.py
@@ -628,7 +628,6 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask
12,
(self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
self.blocksize,
speculate_max_draft_token_num + 1,
)
if self.use_dynamic_quant:
cache_quant_type = "block_wise_fp8"
1 change: 0 additions & 1 deletion tests/layers/test_append_attention_with_output.py
@@ -479,7 +479,6 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask
12,
(self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
self.blocksize,
speculate_max_draft_token_num + 1,
)

# Warm up