diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
index b9c951d391..085d9538bc 100644
--- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
+++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
@@ -11,10 +11,11 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-
+#include "cute/tensor.hpp"
 #include "helper.h"
 #include "paddle/extension.h"
 #include "paddle/phi/core/memory/memcpy.h"
+#include "utils.cuh"
 
 template
 __global__ void
@@ -116,6 +117,93 @@ void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
                max_len_tensor.data<int>(),
                batch_size);
 }
+
+template <int config_size>
+__global__ void search_chunk_size_for_mla(
+    const int *__restrict__ seq_lens_q,
+    const int *__restrict__ seq_lens_encoder,
+    const int *__restrict__ seq_lens_decoder,
+    int *__restrict__ num_blocks_x,
+    int *__restrict__ res_chunk_size,
+    const int bsz,
+    const int set_chunk_size,
+    const int block_size,
+    const int sm_cout) {
+  const uint32_t conf_id = threadIdx.x;
+  int gridx = 0;
+  if (set_chunk_size > 0 && conf_id == 0) {
+    for (uint32_t bid = 0; bid < bsz; bid++) {
+      int seq_len = seq_lens_q[bid];
+      int seq_len_encoder = seq_lens_encoder[bid];
+      int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
+      if (seq_len == 0 || seq_len_encoder > 0) continue;
+
+      int loop_times;
+      loop_times = cute::ceil_div(seq_len_decoder, set_chunk_size);
+      gridx += loop_times;
+    }
+    *num_blocks_x = gridx;
+    *res_chunk_size = set_chunk_size;
+  } else if (conf_id < config_size) {
+    __shared__ int gridx_shared[config_size];
+    // chunk_size is a multiple of 64
+    const int chunk_size = block_size << conf_id;
+    for (uint32_t bid = 0; bid < bsz; bid++) {
+      int seq_len = seq_lens_q[bid];
+      int seq_len_encoder = seq_lens_encoder[bid];
+      int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
+      if (seq_len == 0 || seq_len_encoder > 0) continue;
+
+      int loop_times;
+      loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
+      gridx += loop_times;
+    }
+    gridx_shared[conf_id] = gridx;
+    __syncthreads();
+    if (threadIdx.x == 0) {
+      uint32_t res_id = 0;
+      uint32_t max_last_wave_block = 0;
+      for (uint32_t i = 1; i < config_size; i++) {
+        uint32_t last_wave_block = gridx_shared[i] % sm_cout;
+        if (last_wave_block >= max_last_wave_block) {
+          res_id = i;
+          max_last_wave_block = last_wave_block;
+        }
+      }
+      *num_blocks_x = gridx_shared[res_id];
+      *res_chunk_size = block_size << res_id;
+    }
+  }
+}
+
+__global__ void split_block_for_mla(const int *__restrict__ seq_lens_q,
+                                    const int *__restrict__ seq_lens_encoder,
+                                    const int *__restrict__ seq_lens_decoder,
+                                    int *__restrict__ batch_ids,
+                                    int *__restrict__ tile_ids_per_batch,
+                                    const int bsz,
+                                    const int chunk_size) {
+  if (threadIdx.x == 0) {
+    int index = 0;
+    for (uint32_t bid = 0; bid < bsz; bid++) {
+      int seq_len = seq_lens_q[bid];
+      int seq_len_encoder = seq_lens_encoder[bid];
+      int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
+
+      if (seq_len == 0) continue;
+
+      int loop_times;
+      loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
+      if (seq_len_encoder > 0) {
+        loop_times = 0;
+      }
+      for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
+        batch_ids[index] = bid;
+        tile_ids_per_batch[index++] = tile_id;
+      }
+    }
+  }
+}
+
 __global__ void split_q_block(const int *__restrict__ seq_lens_q,
                               const int *__restrict__ seq_lens_encoder,
                               int *__restrict__ batch_ids,
@@ -230,6 +318,9 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
   paddle::Tensor kv_tile_ids_per_batch;
   paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
   paddle::Tensor max_len_kv_cpu;      /*cpu*/
+  paddle::Tensor decoder_num_blocks_x;
+  paddle::Tensor decoder_chunk_size_device;
+  paddle::Tensor decoder_chunk_size_cpu; /*cpu*/
 
   auto max_len_kv =
       GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
@@ -239,6 +330,103 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
 
   max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
 
+  // decoder
+  if (max_dec_len_this_time > 0) {
+    const bool mla_use_tensorcore = GetMlaUseTensorcore();
+    if (mla_use_tensorcore && group_size <= 64) {
+      const int set_chunk_size = get_mla_dec_chunk_size(bsz);
+      decoder_chunk_size_device = GetEmptyTensor(
+          {1}, paddle::DataType::INT32, seq_lens_encoder.place());
+      decoder_num_blocks_x = GetEmptyTensor(
+          {1}, paddle::DataType::INT32, seq_lens_encoder.place());
+
+      int device;
+      cudaGetDevice(&device);
+      int sm_cout;
+      cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device);
+      constexpr int config_size =
+          12;  // search space for chunk size: [64, 128, 256, ..., 131072]
+
+      search_chunk_size_for_mla<config_size>
+          <<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
+                                 seq_lens_encoder.data<int>(),
+                                 seq_lens_decoder.data<int>(),
+                                 decoder_num_blocks_x.data<int>(),
+                                 decoder_chunk_size_device.data<int>(),
+                                 bsz,
+                                 set_chunk_size,
+                                 block_size,
+                                 sm_cout);
+
+      decoder_chunk_size_cpu =
+          decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
+      const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
+
+      const uint32_t decoder_max_tile_size_per_bs_q =
+          div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
+      const uint32_t decoder_batch_shape = bsz *
+          decoder_max_tile_size_per_bs_q;
+      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(),
+          0, decoder_batch_shape * sizeof(int32_t), stream));
+      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
+          0, decoder_batch_shape * sizeof(int32_t), stream));
+      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(),
+          0, sizeof(int32_t), stream));
+
+      std::cout << "-----------------------------------------------------------"
+                << std::endl;
+      std::cout << "chunk size1:================================ " << chunk_size
+                << std::endl;
+      std::cout << "-----------------------------------------------------------"
+                << std::endl;
+      split_block_for_mla<<<1, 32, 0, stream>>>(
+          seq_lens_this_time.data<int>(),
+          seq_lens_encoder.data<int>(),
+          seq_lens_decoder.data<int>(),
+          decoder_batch_ids.data<int>(),
+          decoder_tile_ids_per_batch.data<int>(),
+          bsz,
+          chunk_size);
+      decoder_num_blocks_x_cpu.copy_(
+          decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
+
+    } else {
+      const uint32_t decoder_max_tile_size_per_bs_q =
+          div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
+      decoder_batch_ids = GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
+                                         paddle::DataType::INT32,
+                                         seq_lens_encoder.place());
+      decoder_tile_ids_per_batch =
+          GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
+                         paddle::DataType::INT32,
+                         seq_lens_encoder.place());
+      decoder_num_blocks_x = GetEmptyTensor(
+          {1}, paddle::DataType::INT32, seq_lens_encoder.place());
+      split_q_block<<<1, 32, 0, stream>>>(
+          seq_lens_this_time.data<int>(),
+          seq_lens_encoder.data<int>(),
+          decoder_batch_ids.data<int>(),
+          decoder_tile_ids_per_batch.data<int>(),
+          decoder_num_blocks_x.data<int>(),
+          bsz,
+          decoder_block_shape_q,
+          group_size);
+      decoder_num_blocks_x_cpu.copy_(
+          decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
+
+      decoder_chunk_size_cpu = paddle::full(
+          {1}, 131072, paddle::DataType::INT32, paddle::CPUPlace());
+    }
+  } else {
+    decoder_chunk_size_cpu =
+        paddle::full({1}, 131072, paddle::DataType::INT32, paddle::CPUPlace());
+    decoder_num_blocks_x = paddle::full(
+        {1}, -1, paddle::DataType::INT32, seq_lens_encoder.place());
+    decoder_num_blocks_x_cpu.copy_(
+        decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
+  }
+
+  // encoder
   if (max_enc_len_this_time > 0) {
     const uint32_t max_tile_size_per_bs_kv =
         div_up(max_enc_dec_len_this_time, block_size);
@@ -292,27 +480,27 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
         GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
   }
 
-  if (max_just_dec_len_this_time > 0) {
-    // Clear buffer
-    const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
-    const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
+  // if (max_just_dec_len_this_time > 0) {
+  //   // Clear buffer
+  //   const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
+  //   const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
+  //   PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
+  //   PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
+  //   PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
 
-    auto decoder_num_blocks_x =
-        GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
-    split_q_block<<<1, 32, 0, stream>>>(
-        seq_lens_this_time.data<int>(),
-        seq_lens_encoder.data<int>(),
-        decoder_batch_ids.data<int>(),
-        decoder_tile_ids_per_batch.data<int>(),
-        decoder_num_blocks_x.data<int>(),
-        bsz,
-        decoder_block_shape_q,
-        group_size);
-    decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
-  }
+  //   auto decoder_num_blocks_x =
+  //       GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
+  //   split_q_block<<<1, 32, 0, stream>>>(
+  //       seq_lens_this_time.data<int>(),
+  //       seq_lens_encoder.data<int>(),
+  //       decoder_batch_ids.data<int>(),
+  //       decoder_tile_ids_per_batch.data<int>(),
+  //       decoder_num_blocks_x.data<int>(),
+  //       bsz,
+  //       decoder_block_shape_q,
+  //       group_size);
+  //   decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
+  // }
 
   return {
     encoder_batch_ids,
@@ -321,6 +509,8 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
     kv_batch_ids,
     kv_tile_ids_per_batch,
     kv_num_blocks_x_cpu, /*cpu*/
+    decoder_num_blocks_x,
+    decoder_chunk_size_cpu, /*cpu*/
     max_len_kv_cpu, /*cpu*/
   };
 }
@@ -342,6 +532,8 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
             paddle::Optional("kv_batch_ids"),
             paddle::Optional("kv_tile_ids_per_batch"),
             paddle::Optional("kv_num_blocks_x_cpu"),
+            paddle::Optional("decoder_num_blocks_x"),
+            paddle::Optional("decoder_chunk_size_cpu"),
             "max_len_kv_cpu"
           })
     .Attrs({
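For reference, the selection rule in `search_chunk_size_for_mla` can be sketched on the host side as follows. This is a minimal Python sketch, not part of the patch: the helper name and the default SM count are hypothetical, and the real kernel additionally honors a positive `set_chunk_size` as an override and skips requests that are still in prefill (`seq_lens_encoder > 0`).

```python
def pick_mla_dec_chunk_size(kv_lens, block_size=64, sm_count=132, config_size=12):
    """Choose the KV chunk size whose decode grid leaves the fullest last wave.

    Candidates are block_size << i for i in [0, config_size); each decode request
    contributes ceil(kv_len / chunk) CTAs to the grid, and the candidate with the
    largest remainder grid % sm_count (fullest final wave) wins.
    """
    best_i, best_last_wave = 0, 0
    for i in range(1, config_size):
        chunk = block_size << i
        grid = sum(-(-kv_len // chunk) for kv_len in kv_lens if kv_len > 0)  # ceil_div
        last_wave = grid % sm_count
        if last_wave >= best_last_wave:
            best_i, best_last_wave = i, last_wave
    return block_size << best_i


# Example: three decode requests with different KV lengths (hypothetical values).
print(pick_mla_dec_chunk_size([1024, 4096, 65536]))
```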
diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index bb2e6944ea..102f718a25 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -416,6 +416,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
     const paddle::Tensor& decoder_tile_ids_per_batch,
     const paddle::Tensor& decoder_num_blocks,
     const paddle::Tensor& decoder_num_blocks_cpu,
+    const paddle::Tensor& decoder_chunk_size_cpu,
     const paddle::Tensor& max_enc_len_this_time,
     const paddle::Tensor& max_dec_len_this_time,
     const paddle::Tensor& max_len_kv,
diff --git a/custom_ops/gpu_ops/env.h b/custom_ops/gpu_ops/env.h
index c7db21ba8f..8f97c7bbab 100644
--- a/custom_ops/gpu_ops/env.h
+++ b/custom_ops/gpu_ops/env.h
@@ -62,3 +62,12 @@ inline bool get_mla_use_tensorcore() {
       mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
   return mla_use_tensorcore != 0 ? true : false;
 }
+inline int get_mla_dec_chunk_size(int bsz) {
+  static const char* mla_dec_chunk_size_env =
+      std::getenv("FLAGS_mla_dec_chunk_size");
+  static const int mla_dec_chunk_size =
+      mla_dec_chunk_size_env == nullptr
+          ? -1
+          : std::stoi(std::string(mla_dec_chunk_size_env));
+  return bsz > 1 ? mla_dec_chunk_size : 64;
+}
diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h
index 468aff1fc4..3b052ee521 100644
--- a/custom_ops/gpu_ops/helper.h
+++ b/custom_ops/gpu_ops/helper.h
@@ -557,3 +557,11 @@ inline int GetSMVersion() {
 
   return sm_version;
 }
+
+inline bool GetMlaUseTensorcore() {
+  static const bool flags_mla_use_tensorcore = get_mla_use_tensorcore();
+  static const bool enable_mla_tensorcore = GetSMVersion() >= 90 ? true : false;
+  const bool mla_use_tensorcore =
+      flags_mla_use_tensorcore && enable_mla_tensorcore;
+  return mla_use_tensorcore;
+}
\ No newline at end of file
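The two switches added above are plain environment variables read with `std::getenv`: `FLAGS_mla_use_tensorcore` gates the tensor-core MLA decode path (combined with the SM90+ check in `GetMlaUseTensorcore`), and `FLAGS_mla_dec_chunk_size` overrides the searched chunk size, with -1 leaving the choice to `search_chunk_size_for_mla` and batch size 1 always pinned to 64. A minimal usage sketch with illustrative values, assuming the flags are exported before the serving process starts:

```python
# Illustrative only: both flags are read via std::getenv on the C++ side, so they
# must be present in the environment before the FastDeploy process is launched.
import os

os.environ["FLAGS_mla_use_tensorcore"] = "1"   # tensor-core MLA decode path (requires SM90+)
os.environ["FLAGS_mla_dec_chunk_size"] = "-1"  # -1: let search_chunk_size_for_mla pick the chunk size
```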
diff --git a/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu b/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu
index f7d4b8ae27..740a63f00c 100644
--- a/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu
+++ b/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu
@@ -79,6 +79,7 @@ void BatchMLAWithPagedKVCacheKernel(
     const paddle::Tensor& num_blocks_x_device,
     const std::string& cache_quant_type_str,
     const int num_blocks_x,
+    const int chunk_size,
     const int max_seq_len,
     const int max_dec_len,
     const float softmax_scale,
@@ -97,7 +98,7 @@ void BatchMLAWithPagedKVCacheKernel(
   const auto q_head_num = meta_data.q_num_heads;
   const auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
   const auto max_block_num = bsz * max_block_num_per_seq;
-  const uint32_t chunk_size = get_max_partition_size(bsz);
+  // const uint32_t chunk_size = get_max_partition_size(bsz);
 
   int q_head_dim = meta_data.head_dims;
 
@@ -185,6 +186,7 @@ template void BatchMLAWithPagedKVCacheKernel(
     const paddle::Tensor& num_blocks_x_device,
     const std::string& cache_quant_type_str,
     const int num_blocks_x,
+    const int chunk_size,
     const int max_seq_len,
     const int max_dec_len,
     const float softmax_scale,
@@ -219,6 +221,7 @@ template void BatchMLAWithPagedKVCacheKernel(
     const paddle::Tensor& num_blocks_x_device,
     const std::string& cache_quant_type_str,
     const int num_blocks_x,
+    const int chunk_size,
     const int max_seq_len,
     const int max_dec_len,
     const float softmax_scale,
diff --git a/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h b/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h
index 97fffe39dc..afd16e2ea8 100644
--- a/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h
+++ b/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h
@@ -56,6 +56,7 @@ void BatchMLAWithPagedKVCacheKernel(
     const paddle::Tensor& num_blocks_x_device,
     const std::string& cache_quant_type_str,
     const int num_blocks_x,
+    const int chunk_size,
     const int max_seq_len,
     const int max_dec_len,
     const float softmax_scale,
diff --git a/custom_ops/gpu_ops/multi_head_latent_attention.cu b/custom_ops/gpu_ops/multi_head_latent_attention.cu
index 98a61e8385..2112b2e704 100644
--- a/custom_ops/gpu_ops/multi_head_latent_attention.cu
+++ b/custom_ops/gpu_ops/multi_head_latent_attention.cu
@@ -38,6 +38,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
     const paddle::Tensor& decoder_tile_ids_per_batch,
     const paddle::Tensor& decoder_num_blocks,
     const paddle::Tensor& decoder_num_blocks_cpu,
+    const paddle::Tensor& decoder_chunk_size_cpu,
     const paddle::Tensor& max_enc_len_this_time,
     const paddle::Tensor& max_dec_len_this_time,
     const paddle::Tensor& max_len_kv,
@@ -67,6 +68,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
   int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
   int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
   int max_len_kv_data = max_len_kv.data<int>()[0];
+  int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
 
   const bool mla_use_tensorcore = get_mla_use_tensorcore();
   auto sm_version = GetSMVersion();
@@ -105,6 +107,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
         decoder_num_blocks,
         cache_quant_type_str,
         decoder_num_blocks_data,
+        chunk_size,
         max_input_length,
         max_len_kv_data,
         softmax_scale,
@@ -161,6 +164,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
     const paddle::Tensor& decoder_tile_ids_per_batch,
     const paddle::Tensor& decoder_num_blocks,
     const paddle::Tensor& decoder_num_blocks_cpu,
+    const paddle::Tensor& decoder_chunk_size_cpu,
     const paddle::Tensor& max_enc_len_this_time,
     const paddle::Tensor& max_dec_len_this_time,
     const paddle::Tensor& max_len_kv,
@@ -224,6 +228,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
         decoder_tile_ids_per_batch,
         decoder_num_blocks,
         decoder_num_blocks_cpu,
+        decoder_chunk_size_cpu,
         max_enc_len_this_time,
         max_dec_len_this_time,
         max_len_kv,
@@ -270,6 +275,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
         decoder_tile_ids_per_batch,
         decoder_num_blocks,
         decoder_num_blocks_cpu,
+        decoder_chunk_size_cpu,
         max_enc_len_this_time,
         max_dec_len_this_time,
         max_len_kv,
@@ -303,113 +309,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
   }
 }
 
-std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
-    const std::vector<int64_t>& query_shape,
-    const std::vector<int64_t>& key_cache_shape,
-    const std::vector<int64_t>& value_cache_shape,
-    const std::vector<int64_t>& seq_lens_encoder_shape,
-    const std::vector<int64_t>& seq_lens_decoder_shape,
-    const std::vector<int64_t>& seq_lens_this_time_shape,
-    const std::vector<int64_t>& cu_seqlens_q_shape,
-    const std::vector<int64_t>& batch_id_per_token_shape,
-    const std::vector<int64_t>& block_tables_shape,
-    const std::vector<int64_t>& encoder_batch_ids_shape,
-    const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
-    const std::vector<int64_t>& encoder_num_blocks_shape,
-    const std::vector<int64_t>& kv_batch_ids_shape,
-    const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
-    const std::vector<int64_t>& kv_num_blocks_shape,
-    const std::vector<int64_t>& decoder_batch_ids_shape,
-    const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
-    const std::vector<int64_t>& decoder_num_blocks_shape,
-    const std::vector<int64_t>& decoder_num_blocks_cpu_shape,
-    const std::vector<int64_t>& max_enc_len_this_time_shape,
-    const std::vector<int64_t>& max_dec_len_this_time_shape,
-    const std::vector<int64_t>& max_len_kv_shape,
-    const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
-    const paddle::optional<std::vector<int64_t>>& query_bias_shape,
-    const paddle::optional<std::vector<int64_t>>& query_out_scales_shape,
-    const paddle::optional<std::vector<int64_t>>& cache_k_quant_scales_shape,
-    const paddle::optional<std::vector<int64_t>>& cache_v_quant_scales_shape,
-    const paddle::optional<std::vector<int64_t>>& cache_k_dequant_scales_shape,
-    const paddle::optional<std::vector<int64_t>>& cache_v_dequant_scales_shape,
-    const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
-    const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
-    const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
-    const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
-    const std::string& compute_dtype,
-    const std::string& cache_quant_type_str,
-    const int nope_size,
-    const int max_input_length,
-    const float softmax_scale,
-    const float quant_max_bound,
-    const float quant_min_bound,
-    const float out_linear_in_scale,
-    const int speculate_max_draft_token_num,
-    const bool causal,
-    const bool speculate_decoder) {
-  const int token_num = query_shape[0];
-  const int kv_num_heads = key_cache_shape[1];
-  const int head_dim_qk = key_cache_shape[3];
-  const int head_dim_v = nope_size;
-  const int q_hidden_size = query_shape[query_shape.size() - 1];
-  const int num_heads = q_hidden_size / head_dim_qk;
-  return {{token_num, num_heads * head_dim_v}};
-}
-
-std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
-    const paddle::DataType& query_dtype,
-    const paddle::DataType& key_cache_dtype,
-    const paddle::DataType& value_cache_dtype,
-    const paddle::DataType& seq_lens_encoder_dtype,
-    const paddle::DataType& seq_lens_decoder_dtype,
-    const paddle::DataType& seq_lens_this_time_dtype,
-    const paddle::DataType& cu_seqlens_q_dtype,
-    const paddle::DataType& batch_id_per_token_dtype,
-    const paddle::DataType& block_tables_dtype,
-    const paddle::DataType& encoder_batch_ids_dtype,
-    const paddle::DataType& encoder_tile_ids_per_batch_dtype,
-    const paddle::DataType& encoder_num_blocks_dtype,
-    const paddle::DataType& kv_batch_ids_dtype,
-    const paddle::DataType& kv_tile_ids_per_batch_dtype,
-    const paddle::DataType& kv_num_blocks_dtype,
-    const paddle::DataType& decoder_batch_ids_dtype,
-    const paddle::DataType& decoder_tile_ids_per_batch_dtype,
-    const paddle::DataType& decoder_num_blocks_dtype,
-    const paddle::DataType& decoder_num_blocks_cpu_dtype,
-    const paddle::DataType& max_enc_len_this_time_dtype,
-    const paddle::DataType& max_dec_len_this_time_dtype,
-    const paddle::DataType& max_len_kv_dtype,
-    const paddle::optional<paddle::DataType>& attn_mask_dtype,
-    const paddle::optional<paddle::DataType>& query_bias_dtype,
-    const paddle::optional<paddle::DataType>& query_out_scales_dtype,
-    const paddle::optional<paddle::DataType>& cache_k_quant_scales_dtype,
-    const paddle::optional<paddle::DataType>& cache_v_quant_scales_dtype,
-    const paddle::optional<paddle::DataType>& cache_k_dequant_scales_dtype,
-    const paddle::optional<paddle::DataType>& cache_v_dequant_scales_dtype,
-    const paddle::optional<paddle::DataType>& cache_k_zp_dtype,
-    const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
-    const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
-    const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
-    const std::string& compute_dtype,
-    const std::string& cache_quant_type_str,
-    const int nope_size,
-    const int max_input_length,
-    const float softmax_scale,
-    const float quant_max_bound,
-    const float quant_min_bound,
-    const float out_linear_in_scale,
-    const int speculate_max_draft_token_num,
-    const bool causal,
-    const bool speculate_decoder) {
-  if (compute_dtype == "bf16") {
-    return {paddle::DataType::BFLOAT16};
-  } else if (compute_dtype == "fp16") {
-    return {paddle::DataType::FLOAT16};
-  } else {
-    PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
-  }
-}
 
 PD_BUILD_STATIC_OP(multi_head_latent_attention)
     .Inputs({"query",
@@ -431,6 +330,7 @@ PD_BUILD_STATIC_OP(multi_head_latent_attention)
              "decoder_tile_ids_per_batch",
              "decoder_num_blocks",
              "decoder_num_blocks_cpu",
+             "decoder_chunk_size_cpu",
              "max_enc_len_this_time",
              "max_dec_len_this_time",
              "max_len_kv",
@@ -457,6 +357,4 @@ PD_BUILD_STATIC_OP(multi_head_latent_attention)
             "speculate_max_draft_token_num: int",
             "causal: bool",
             "speculate_decoder: bool"})
-    .SetKernelFn(PD_KERNEL(MultiHeadLatentAttention))
-    .SetInferShapeFn(PD_INFER_SHAPE(MultiHeadLatentAttentionInferShape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(MultiHeadLatentAttentionInferDtype));
+    .SetKernelFn(PD_KERNEL(MultiHeadLatentAttention));
diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
index 68a469e795..2ae0cc898b 100644
--- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
+++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
@@ -195,6 +195,8 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             metadata.kv_batch_ids,
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
+            metadata.decoder_num_blocks,
+            metadata.decoder_chunk_size_cpu,
             metadata.max_len_kv,
         ) = get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
@@ -210,6 +212,12 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             self.block_size,
             self.speculate_max_draft_token_num + 1,
         )
+        print("metadata.kv_batch_ids", metadata.kv_batch_ids)
+        print("metadata.kv_tile_ids_per_batch", metadata.kv_tile_ids_per_batch)
+        print("metadata.kv_num_blocks", metadata.kv_num_blocks)
+        print("metadata.decoder_num_blocks", metadata.decoder_num_blocks)
+        print("metadata.decoder_chunk_size_cpu", metadata.decoder_chunk_size_cpu)
+        print("metadata.max_len_kv", metadata.max_len_kv)
 
         # MLA
         metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
@@ -369,6 +377,7 @@ def forward_decode(
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.decoder_num_blocks_cpu,
+            metadata.decoder_chunk_size_cpu,
             metadata.max_enc_len_this_time,
             metadata.max_dec_len_this_time,
             metadata.max_len_kv,
@@ -450,6 +459,11 @@ def forward_mixed(
                 causal=self.causal,
                 **self.flash_attn_kwargs,
             )[0]
+            print("mix_FA3+++++++++++++++++++fmha_out", fmha_out)
+            print("nan || -inf FA3 :", paddle.isfinite(fmha_out).all())
+            is_finite = paddle.isfinite(fmha_out).all().item()
+            is_normal = (paddle.abs(fmha_out) < 1e6).all().item()
+            print("nan/inf/large outlier:", is_finite and is_normal)
 
             return fmha_out
 
@@ -488,8 +502,9 @@ def forward_mixed(
                 metadata.kv_num_blocks,
                 forward_meta.decoder_batch_ids,
                 forward_meta.decoder_tile_ids_per_batch,
+                metadata.decoder_num_blocks,
                 forward_meta.decoder_num_blocks_cpu,
-                forward_meta.decoder_num_blocks_cpu,
+                metadata.decoder_chunk_size_cpu,
                 metadata.max_enc_len_this_time,
                 metadata.max_dec_len_this_time,
                 metadata.max_len_kv,
@@ -516,5 +531,11 @@ def forward_mixed(
                 True,  # causal
                 speculate_decoder,
             )
+            print("nan || -inf MLA :", paddle.isfinite(fmha_out).all())
+            is_finite = paddle.isfinite(fmha_out).all().item()
+            is_normal = (paddle.abs(fmha_out) < 1e6).all().item()
+            print("nan/inf/large outlier:", is_finite and is_normal)
+            print("mix_MLA=====================fmha_out", fmha_out)
+
             return fmha_out
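The finiteness check added after both the FA3 and MLA paths above is repeated verbatim; if it is kept beyond debugging, it could be factored into a small helper along these lines (a sketch only, not part of the patch; the 1e6 threshold mirrors the one used above):

```python
import paddle


def fmha_out_is_sane(t: paddle.Tensor, max_abs: float = 1e6) -> bool:
    """Return True when the tensor contains no nan/inf and no abnormally large values."""
    is_finite = bool(paddle.isfinite(t).all().item())
    is_normal = bool((paddle.abs(t) < max_abs).all().item())
    return is_finite and is_normal
```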
diff --git a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py
index dd57b52593..0898e1f64f 100644
--- a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py
+++ b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py
@@ -49,6 +49,8 @@ def get_block_shape_and_split_kv_block(
             kv_batch_ids,
             kv_tile_ids_per_batch,
             kv_num_blocks,
+            decoder_num_blocks,
+            decoder_chunk_size_cpu,
             max_len_kv_cpu,
         ) = get_block_shape_and_split_kv_block_cuda(
             seq_lens_encoder,
@@ -71,6 +73,8 @@ def get_block_shape_and_split_kv_block(
             kv_batch_ids,
             kv_tile_ids_per_batch,
             kv_num_blocks,
+            decoder_num_blocks,
+            decoder_chunk_size_cpu,
             max_len_kv_cpu,
         )
     else:
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index b65925be21..5b181ef8bc 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -401,6 +401,7 @@ def forward(
             fmha_out = fmha_out + fmha_out_decode
 
         output = self.o_proj(fmha_out)
+        print("out__attn______________", output)
         return output
 
     def load_state_dict(self, state_dict):
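As a consistency check on the two new outputs returned by the wrapper, `decoder_num_blocks` should match the total number of KV chunks of size `decoder_chunk_size_cpu` produced by the decode-only requests in the batch, mirroring `split_block_for_mla`. A small sketch under that assumption (helper name and example lengths are hypothetical):

```python
def expected_decoder_num_blocks(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder, chunk_size):
    """Mirror of split_block_for_mla's tile count: one CTA per KV chunk per decode request."""
    total = 0
    for dec_len, q_len, enc_len in zip(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder):
        if q_len == 0 or enc_len > 0:  # skip finished slots and requests still in prefill
            continue
        total += -(-(dec_len + q_len) // chunk_size)  # ceil_div of total KV length
    return total


# Example with hypothetical lengths and a searched chunk size of 1024.
print(expected_decoder_num_blocks([4096, 0, 800], [1, 0, 1], [0, 0, 0], 1024))
```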