diff --git a/.github/workflows/_logprob_test_linux.yml b/.github/workflows/_logprob_test_linux.yml index 8ca3c7d7f64..82a1ec68f87 100644 --- a/.github/workflows/_logprob_test_linux.yml +++ b/.github/workflows/_logprob_test_linux.yml @@ -163,7 +163,7 @@ jobs: -d "{\"messages\": [{\"role\": \"user\", \"content\": \"1+1=?\"}], \"logprobs\": true}" set +e rm -rf ./baseline_output - cp -r baseline/ERNIE-4.5-0.3B-Paddle ./baseline_output + cp -r baseline_1131/ERNIE-4.5-0.3B-Paddle ./baseline_output LOGPROB_EXIT_CODE=0 python3.10 lanucher.py --request_template TOKEN_LOGPROB --url http://localhost:${FD_API_PORT}/v1/chat/completions --case ./cases/demo.yaml --concurrency 1 --name demo --exe logprob || LOGPROB_EXIT_CODE=$? echo "LOGPROB_EXIT_CODE=${LOGPROB_EXIT_CODE}" > /workspace/exit_code.env diff --git a/custom_ops/gpu_ops/append_attention/attention_func.cuh b/custom_ops/gpu_ops/append_attention/attention_func.cuh new file mode 100644 index 00000000000..583ba728a8c --- /dev/null +++ b/custom_ops/gpu_ops/append_attention/attention_func.cuh @@ -0,0 +1,1035 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "mma_tensor_op.cuh" +#include "utils.cuh" + +template +__device__ __forceinline__ void init_states(float (*o_frag)[num_frags_y][8], + float (*m)[2], + float (*d)[2]) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[fx][fy][reg_id] = 0.f; + } + } + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + if constexpr (std::is_same::value) { + m[fx][j] = -5e4f; + } else if constexpr (std::is_same::value) { + m[fx][j] = -3.0e+30f; + } + d[fx][j] = 1.f; + } + } +} + +template +__device__ __forceinline__ void load_block_table_per_chunk( + const int32_t* block_table_chunk_start, + int32_t* block_table_smem, + uint32_t chunk_start, + uint32_t chunk_end, + uint32_t tid, + uint32_t wid) { + uint32_t len = chunk_end / BLOCK_SIZE - chunk_start / BLOCK_SIZE; + for (uint32_t i = 0; i < div_up(len, 128); i++) { + uint32_t offset = (wid * kWarpSize + tid) * i; + if (offset <= len) { + block_table_smem[offset] = block_table_chunk_start[offset]; + } + } +} + +// load q from global memory to shared memory +template +__device__ __forceinline__ void load_q_global_smem_multi_warps( + T* q_ptr_base, + smem_t* q_smem, + uint32_t q_idx_base, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t q_smem_offset_w = // [NUM_WARP_Q, num_frags_x, 16, head_dim] + smem_t::get_permuted_offset(ty * 4 + tx / 8, + tx % 8); // 4 * 64 + + const uint32_t tx_offset = tx / 8; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; 
++fx) { + const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset; +#pragma unroll + const int j = ty; + const uint32_t offset_now = base_offset + j * 4; + const uint32_t n_offset = offset_now / group_size; + const uint32_t h_offset = offset_now % group_size; + T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { + q_smem->load_128b_async( + q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); + q_smem_offset_w = + q_smem->advance_offset_by_column<8>(q_smem_offset_w, fyo); + q_ptr += 8 * num_elems_per_128b(); + } + q_smem_offset_w = + q_smem->advance_offset_by_row<16, num_vecs_per_head>(q_smem_offset_w) - + 2 * num_frags_y; + } +} + +template +__device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps( + smem_t* q_smem, // [num_frags_x * 16, num_frags_y * 16] + const float sm_scale) { + constexpr int vec_size = 16 / sizeof(T); + using LoadT = AlignedVector; + LoadT tmp_vec; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + +#pragma unroll + for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; ++i) { + const int offset = i * 1024 + ty * 256 + tx * 8; + Load(reinterpret_cast(q_smem->base) + offset, &tmp_vec); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + tmp_vec[reg_id] *= sm_scale; + } + Store(tmp_vec, reinterpret_cast(q_smem->base) + offset); + } +} + +template +__device__ __forceinline__ void produce_k_blockwise_c8( + smem_t smem, + uint32_t* smem_offset, + CacheT* cache_k, + const int* block_table_now, + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_b_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const uint32_t const_k_offset) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = + head_dim / num_elems_per_128b(); // 8 + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; +#pragma unroll + for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_k_now = cache_k + block_id * kv_n_stride + const_k_offset; +#pragma unroll + for (uint32_t i = 0; i < 2 * num_frags_z * 4 / num_warps; + ++i) { // m num_frags_z * 16 / (num_warps * 4) +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 8; ++j) { + smem.load_128b_async(*smem_offset, cache_k_now, true); + *smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>( + *smem_offset, j); + cache_k_now += 8 * num_elems_per_128b(); + } + kv_idx += num_warps * 4; + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + num_frags_y; // num_frags_y / 4 * 4 + cache_k_now += num_warps * 4 * kv_b_stride - + num_frags_y * num_elems_per_128b(); + } + } + *smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head; +} + +template +__device__ __forceinline__ void produce_v_blockwise_c8( + smem_t smem, + uint32_t* smem_offset, + CacheT* cache_v, + const int* block_table_now, + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_d_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const uint32_t const_v_offset) { + constexpr uint32_t num_vecs_per_blocksize = + 
block_size / num_elems_per_128b(); // 8 + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + tx % 4 * num_elems_per_128b(); + +#pragma unroll + for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_v_now = cache_v + block_id * kv_n_stride + const_v_offset; + +#pragma unroll + for (uint32_t i = 0; i < num_frags_y * 2 / num_warps; + ++i) { // m (num_frags_y * 16 / (num_warps * 8)) +#pragma unroll + for (uint32_t j = 0; j < 2 * num_frags_z / 4; ++j) { + smem.load_128b_async(*smem_offset, cache_v_now, true); + *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>( + *smem_offset, j); + cache_v_now += 4 * num_elems_per_128b(); + kv_idx += 4 * num_elems_per_128b(); + } + kv_idx -= 2 * num_frags_z * num_elems_per_128b(); + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + 2 * num_frags_z; // num_frags_z / 4 * 4 + cache_v_now += num_warps * 8 * kv_d_stride - + 2 * num_frags_z * num_elems_per_128b(); + } + kv_idx += block_size; + } + *smem_offset -= NUM_WARP_KV / 2 * num_frags_y * 16 * num_vecs_per_blocksize; +} + +template +__device__ __forceinline__ void produce_kv_dynamic_scale_gmem2smem_async( + smem_t kv_scale_smem, + const int* block_table_now, + const T* cache_kv_scale, + const uint32_t kv_idx, + const uint32_t kv_num_heads, + const uint32_t kv_head_idx, + const uint32_t chunk_end) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + const uint32_t tid = ty * 32 + tx; + // 1 warp 32 tokens + if (tid < block_size / 8 * 2) { + const uint32_t kv_idx_now = kv_idx + block_size * tid / 8; + int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); + if (block_id < 0) block_id = 0; + const int kv_idx_this_thread = kv_idx + tid * 8; + const T* cache_k_scale_now = cache_kv_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size + tid % 8 * 8; + kv_scale_smem.load_128b_async( + tid, cache_k_scale_now, kv_idx_this_thread < chunk_end); + } +} + +template +__device__ __forceinline__ void produce_k_dynamic_scale_smem2reg( + T* k_smem_scale, T* cache_k_reg) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + // 1 warp 32 tokens + const uint32_t row_id = tx / 4; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = ty * 32 + fz * 16 + row_id; + cache_k_reg[fz * 2] = k_smem_scale[scale_idx]; + cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8]; + } +} + +template +__device__ __forceinline__ void produce_v_dynamic_scale_smem2reg( + T* v_smem_scale, T* cache_v_reg) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + + // 1 warp 32 tokens + const uint32_t row_id = tx % 4 * 2; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = ty * 32 + fz * 16 + row_id; + cache_v_reg[fz * 4] = v_smem_scale[scale_idx]; + cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1]; + cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8]; + cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9]; + } +} + +template +__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, + uint32_t* q_smem_offset_r, + smem_t* k_smem, + uint32_t* k_smem_offset_r, + const T* cache_k_scale, + float (*s_frag)[num_frags_z][8]) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + head_dim / 
num_elems_per_128b(); + + uint32_t a_frag[num_frags_x][2][4], b_frag[4], b_frag_dq[4]; + +#pragma unroll + for (uint32_t ky = 0; ky < num_frags_y / 2; ++ky) { // k + // load q +#pragma unroll + for (uint32_t fy = 0; fy < 2; ++fy) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx][fy]); + + *q_smem_offset_r = + q_smem->advance_offset_by_row<16, num_vecs_per_head_q>( + *q_smem_offset_r); + } + *q_smem_offset_r = + q_smem->advance_offset_by_column<2>(*q_smem_offset_r, ky * 2 + fy) - + num_frags_x * 16 * num_vecs_per_head_q; + } + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + // load + k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); + *k_smem_offset_r = k_smem->advance_offset_by_row<16, num_vecs_per_head_k>( + *k_smem_offset_r); +#pragma unroll + for (uint32_t fy = 0; fy < 2; ++fy) { + T* b_frag_dq_T = reinterpret_cast(b_frag_dq); + convert_c8(b_frag_dq_T, b_frag[fy * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fy * 2 + 1]); + // scale zp + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + const int scale_col = (ky * 2 + fy) * 4; + b_frag_dq_T[0] *= cache_k_scale[scale_col]; + b_frag_dq_T[1] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_k_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_k_scale[scale_col]; + b_frag_dq_T[5] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_k_scale[scale_col + 3]; + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_k_scale[0]; + } + } + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4]; + } + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + if (ky == 0 && fy == 0) { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx][fy], b_frag_dq); + } else { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx][fy], b_frag_dq); + } + } + } + } + *k_smem_offset_r = k_smem->advance_offset_by_column<2, num_vecs_per_head_k>( + *k_smem_offset_r, ky) - + num_frags_z * 16 * num_vecs_per_head_k; + } + *q_smem_offset_r -= num_frags_y * 2; + *k_smem_offset_r -= num_frags_y / 2 * 2; +} + +template +__device__ __forceinline__ void mask_s(const bool* attn_mask, + const uint32_t qo_idx_base, + const uint32_t kv_idx_base, + const uint32_t qo_len, + const uint32_t kv_len, + const uint32_t chunk_end, + const uint32_t attn_mask_len, + float (*s_frag)[num_frags_z][8], + const int* mask_offset = nullptr, + const int sliding_window = 0) { + const uint32_t tx = threadIdx.x; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + const uint32_t q_idx = (qo_idx_base + fx * 16 + tx / 4 + + 8 * ((reg_id % 4) / 2)) / + group_size, + kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + + 8 * (reg_id / 4) + reg_id % 2; + bool out_of_boundary; + if (mask_offset) { + out_of_boundary = q_idx < qo_len + ? (kv_idx >= mask_offset[q_idx * 2 + 1] || + kv_idx < mask_offset[q_idx * 2]) + : true; + } else if (sliding_window > 0) { + bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - + (int)qo_len - sliding_window; + out_of_boundary = (causal ? 
(kv_idx > kv_len + q_idx - qo_len || + out_of_window || (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + } else { + out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len || + (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + if (attn_mask != nullptr && kv_idx > kv_len - qo_len && + kv_idx < chunk_end && q_idx < attn_mask_len) { + const int32_t mask_idx = + q_idx * attn_mask_len + kv_idx - kv_len + qo_len; + bool mask = attn_mask[mask_idx]; + out_of_boundary |= mask; + } + } + + if constexpr (std::is_same::value) { + s_frag[fx][fz][reg_id] = + out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id]; + } else if constexpr (std::is_same::value) { + s_frag[fx][fz][reg_id] = + out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id]; + } + } + } + } +} + +template +__device__ __forceinline__ void update_mdo_states( + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*m)[2], + float (*d)[2]) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + uint32_t j_id = j * 2; + float m_prev = m[fx][j]; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + float* s_frag_tmp = s_frag[fx][fz] + j_id; + float m_local = max(max(s_frag_tmp[0], s_frag_tmp[1]), + max(s_frag_tmp[4], s_frag_tmp[5])); + m[fx][j] = max(m[fx][j], m_local); + } + m[fx][j] = max(m[fx][j], __shfl_xor_sync(-1, m[fx][j], 0x2, 32)); + m[fx][j] = max(m[fx][j], __shfl_xor_sync(-1, m[fx][j], 0x1, 32)); + float o_scale = expf(m_prev - m[fx][j]); + d[fx][j] *= o_scale; + float2 fp2_scale = make_float2(o_scale, o_scale); +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + // o_frag[fx][fy][j * 2 + 0] *= o_scale; + // o_frag[fx][fy][j * 2 + 1] *= o_scale; + // o_frag[fx][fy][j * 2 + 4] *= o_scale; + // o_frag[fx][fy][j * 2 + 5] *= o_scale; + + float2* o_frag_ptr = reinterpret_cast(o_frag[fx][fy] + j_id); + // printf("fp2_len:%d, %d", sizeof(o_frag_ptr[0]), sizeof(fp2_scale)); + o_frag_ptr[0] = fast_float2_mul(o_frag_ptr[0], fp2_scale); + o_frag_ptr[2] = fast_float2_mul(o_frag_ptr[2], fp2_scale); + } + float tmp_m = m[fx][j]; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + float* s_frag_ptr = s_frag[fx][fz] + j_id; + s_frag_ptr[0] = __expf(s_frag_ptr[0] - tmp_m); + s_frag_ptr[1] = __expf(s_frag_ptr[1] - tmp_m); + s_frag_ptr[4] = __expf(s_frag_ptr[4] - tmp_m); + s_frag_ptr[5] = __expf(s_frag_ptr[5] - tmp_m); + // s_frag[fx][fz][j * 2 + 0] = + // __expf(s_frag[fx][fz][j * 2 + 0] - m[fx][j]); + // s_frag[fx][fz][j * 2 + 1] = + // __expf(s_frag[fx][fz][j * 2 + 1] - m[fx][j]); + // s_frag[fx][fz][j * 2 + 4] = + // __expf(s_frag[fx][fz][j * 2 + 4] - m[fx][j]); + // s_frag[fx][fz][j * 2 + 5] = + // __expf(s_frag[fx][fz][j * 2 + 5] - m[fx][j]); + } + } + } +} + +template +__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( + smem_t* v_smem, + uint32_t* v_smem_offset_r, + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*d)[2], + T* cache_v_scale) { + constexpr uint32_t num_vecs_per_blocksize = + block_size / num_elems_per_128b(); + + T s_frag_f16[num_frags_x][num_frags_z][8]; + uint32_t b_frag[4], b_frag_dq[4]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + vec_cast(s_frag_f16[fx][fz], s_frag[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]); 
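+      // Note: s_frag already holds exp(s - m) after update_mdo_states, so this rowsum
+      // accumulates the per-row softmax denominator into d; normalize_d divides it out later.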
+ } + } + +#pragma unroll + for (uint32_t kz = 0; kz < num_frags_z / 2; ++kz) { // k +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + v_smem->ldmatrix_m8n8x4(*v_smem_offset_r, b_frag); + *v_smem_offset_r = + v_smem->advance_offset_by_row<16, num_vecs_per_blocksize>( + *v_smem_offset_r); +#pragma unroll + for (uint32_t fz = 0; fz < 2; ++fz) { + // dequant b_frag -> b_frag_dq + T* b_frag_dq_T = reinterpret_cast(b_frag_dq); + convert_c8(b_frag_dq_T, b_frag[fz * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); + // scale zp + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; + } + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[0]; + } + } + } else { + const int scale_col = (kz * 2 + fz) * 4; + b_frag_dq_T[0] *= cache_v_scale[scale_col]; + b_frag_dq_T[1] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_v_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_v_scale[scale_col]; + b_frag_dq_T[5] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_v_scale[scale_col + 3]; + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16 + mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[fx][fy], + (uint32_t*)(s_frag_f16[fx][kz * 2 + fz]), + b_frag_dq); + } + } + } + *v_smem_offset_r -= num_frags_y * 16 * num_vecs_per_blocksize; + } +} + +template +__device__ __forceinline__ void merge_block_res(float (*o_frag)[num_frags_y][8], + float* md_smem, + float (*m)[2], + float (*d)[2], + const uint32_t wid, + const uint32_t tid) { + float2* smem_md = reinterpret_cast( + md_smem + num_frags_x * num_frags_y * 1024); // 4 * 32 * 8 +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + smem_md[((wid * num_frags_x + fx) * 2 + j) * 32 + tid] = + make_float2(m[fx][j], d[fx][j]); + } + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + float2* md_smem_start = + (float2*)(md_smem + + ((wid * num_frags_x + fx) * num_frags_y + fy) * 32 * 8 + + tid * 2); +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + md_smem_start[i * 32] = ((float2*)(&o_frag[fx][fy][0]))[i]; + } + // *(reinterpret_cast( + // md_smem + (((wid * num_frags_x + fx) * num_frags_y + fy) * 32 + + // tid) * + // 8)) = + // *(reinterpret_cast(&o_frag[fx][fy][0])); + // *(reinterpret_cast( + // md_smem + + // (((wid * num_frags_x + fx) * num_frags_y + fy) * 32 + tid) * 8 + + // 4)) = + // *(reinterpret_cast(&o_frag[fx][fy][4])); + } + } + __syncthreads(); + float o_scale[4][num_frags_x][2]; + + // deal md/scale +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + float m_new; + float d_new = 1.f; + if constexpr (std::is_same::value) { + m_new = -5e4f; + } else { + m_new = -3.0e+30f; + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + float2 md = smem_md[((i * num_frags_x + fx) * 2 + j) * 32 + tid]; + float m_prev = m_new, d_prev = d_new; + m_new = max(m_new, md.x); + // d_new = d_prev * expf(m_prev - m_new) + md.y * expf(md.x - m_new); + d_new = fmaf(d_prev, expf(m_prev - m_new), md.y * expf(md.x - m_new)); + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + float2 md = 
smem_md[((i * num_frags_x + fx) * 2 + j) * 32 + tid]; + o_scale[i][fx][j] = expf(md.x - m_new); + } + m[fx][j] = m_new; + d[fx][j] = d_new; + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + // num_warps * 32 * 8 each time + // AlignedVector o_new_fp2; + float2* o_new_fp2 = reinterpret_cast(&o_frag[fx][fy][0]); + // float2* o_new_fp2 = reinterpret_cast(&o_new[0]); +#pragma + for (uint32_t o_id = 0; o_id < 4; ++o_id) { + o_new_fp2[o_id] = make_float2(0.f, 0.f); + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + // AlignedVector oi; + AlignedVector oi_fp2; + float2* md_smem_start = + (float2*)(md_smem + + ((i * num_frags_x + fx) * num_frags_y + fy) * 32 * 8 + + tid * 2); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) { + oi_fp2[reg_id] = md_smem_start[reg_id * 32]; + } +#pragma unroll + for (uint32_t reg_fp2_id = 0; reg_fp2_id < 4; ++reg_fp2_id) { + float o_scale_fp2_tmp = o_scale[i][fx][reg_fp2_id % 2]; + o_new_fp2[reg_fp2_id] = + fast_float2_fma(oi_fp2[reg_fp2_id], + make_float2(o_scale_fp2_tmp, o_scale_fp2_tmp), + o_new_fp2[reg_fp2_id]); + } + } + } + } +} + +template +__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], + float (*d)[2]) { + float d_rcp[num_frags_x][2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + d_rcp[fx][j] = 1.f / d[fx][j]; + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[fx][fy][reg_id] = + o_frag[fx][fy][reg_id] * d_rcp[fx][(reg_id % 4) / 2]; + } + } + } +} + +template +__device__ __forceinline__ void write_o_reg_gmem_multi_warps( + float (*o_frag)[num_frags_y][8], + smem_t* o_smem, + OutT* o_ptr_base, + uint32_t o_idx_base, + const uint32_t q_head_idx_base, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr int VEC_SIZE = 16 / sizeof(T); + // [num_warps * num_frags_x * 16, num_frags_y * 16] + if (ty == 0) { + // [num_frags_x * 16, num_frags_y * 16] +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + uint32_t o_frag_f16[4]; + vec_cast((T*)o_frag_f16, o_frag[fx][fy]); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(fx * 16 + tx / 4, + fy * 2); + ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; + ((uint32_t*)(o_smem->base + o_smem_offset_w + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[tx % 4] = + o_frag_f16[2]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[3]; + } + } + } + __syncthreads(); + + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(ty * 4 + tx / 8, tx % 8); + + const uint32_t tx_offset = tx / 8; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + const uint32_t base_offset = o_idx_base + fx * 16 + tx_offset; +#pragma unroll + const int j = ty; + const uint32_t offset_now = base_offset + j * 4; + const uint32_t n_offset = offset_now / group_size; + const uint32_t h_offset = offset_now % 
group_size; + + OutT* o_ptr = o_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { + if (n_offset < qo_upper_bound) { + o_smem->store_128b(o_smem_offset_w, o_ptr); + } + o_ptr += 8 * num_elems_per_128b(); + o_smem_offset_w = + o_smem->advance_offset_by_column<8>(o_smem_offset_w, fyo); + } + o_smem_offset_w = + o_smem->advance_offset_by_row<16, num_vecs_per_head>(o_smem_offset_w) - + 2 * num_frags_y; + } +} + +template +struct prefill_softmax_state_t { + AlignedVector o; + float m; + float d; + + __device__ __forceinline__ void init() { + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2*)(&o) + i) = make_half2(0, 0); + } + } else if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0); + } + } + d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.38953e38f; + } + } + + __device__ __forceinline__ void merge( + const AlignedVector& other_o, float other_m, float other_d) { + float m_prev = m, d_prev = d; + m = m_prev > other_m ? m_prev : other_m; + const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m); + const T scale1_T = static_cast(scale1), + scale2_T = static_cast(scale2); + d = d_prev * scale1 + other_d * scale2; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] = o[i] * scale1_T + other_o[i] * scale2_T; + } + } + + __device__ __forceinline__ void normalize() { + const T d_t = static_cast(d); +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] /= d_t; + } + } + + __device__ __forceinline__ void normalize(float current_sink) { + const T d_t = static_cast(d + __expf(current_sink - m)); +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] /= d_t; + } + } +}; + +template +__global__ void merge_chunks_kernel( + const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads, + // head_dim] + const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads] + const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads] + const int* __restrict__ seq_lens_q, + const int* __restrict__ seq_lens_kv, + const int* __restrict__ seq_lens_encoder, + const int* __restrict__ batch_id_per_token, + const int* __restrict__ cu_seqlens_q, + const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const T* __restrict__ sinks, // [q_num_heads] + const int* __restrict__ chunk_size_ptr, + T* __restrict__ out, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_seq_len, + const int num_chunks, + const int num_heads, + const int head_dim, + const int token_num, + const int max_tokens_per_batch = 5) { + const int vid = threadIdx.x, ty = threadIdx.y; + const int hid = blockIdx.y; + __shared__ T smem[bdy * HEAD_DIM]; + __shared__ float md_smem[bdy * 2]; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) { + const uint32_t bid = batch_id_per_token[qid]; + if (bid == -1) { + continue; + } + const uint32_t local_seq_id = qid - cu_seqlens_q[bid]; + const int seq_len_q = seq_lens_q[bid]; + if (seq_len_q == 0) continue; + int seq_len_kv = seq_lens_kv[bid]; + if (seq_len_kv == 0) continue; + seq_len_kv += seq_len_q; + const 
int num_chunks_this_seq = div_up(seq_len_kv, *chunk_size_ptr); + + using LoadT = AlignedVector; + LoadT load_vec; + LoadT res_vec; + if (num_chunks_this_seq == 1) { + if (ty == 0) { + uint32_t offset = ((bid * max_tokens_per_batch + local_seq_id) * + num_chunks * num_heads + + hid) * + head_dim + + vid * vec_size; + Load(&multi_out[offset], &load_vec); + Store( + load_vec, + &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); + } + } else { + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2*)(&res_vec) + i) = make_half2(0, 0); + } + } else { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0); + } + } + float m; + float d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.0e+30f; + } +#pragma unroll 2 + for (int i = ty; i < num_chunks_this_seq; i += bdy) { + uint32_t offset; + + offset = + ((bid * max_tokens_per_batch + local_seq_id) * num_chunks + i) * + num_heads + + hid; + float m_prev = m; + float d_prev = d; + const float m_now = multi_m[offset]; + const float d_now = multi_d[offset]; + m = max(m_prev, m_now); + + offset = ((bid * max_tokens_per_batch + local_seq_id) * num_chunks * + num_heads + + i * num_heads + hid) * + head_dim + + vid * vec_size; + Load(&multi_out[offset], &load_vec); + const float scale1 = expf(m_prev - m), scale2 = expf(m_now - m); + const T scale1_T = static_cast(scale1), + scale2_T = static_cast(scale2); + d = d * scale1 + d_now * scale2; +#pragma unroll + for (int j = 0; j < vec_size; j++) { + res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T; + } + } + // store ty res + Store(res_vec, &smem[ty * head_dim + vid * vec_size]); + md_smem[2 * ty] = m; + md_smem[2 * ty + 1] = d; + __syncthreads(); + if (ty == 0) { + // merge bdy + prefill_softmax_state_t st; + st.init(); +#pragma unroll + for (int i = 0; i < bdy; i++) { + Load(&smem[i * head_dim + vid * vec_size], &load_vec); + const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1]; + st.merge(load_vec, m_tmp, d_tmp); + } + + if (sinks) { + float current_sink = static_cast(sinks[hid]); + st.normalize(current_sink); + } else { + st.normalize(); + } + + const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size; + AlignedVector shift_bias_vec; + AlignedVector smooth_weight_vec; + AlignedVector out_vec; + if (shift_bias) { + Load(shift_bias + shift_smooth_offset, &shift_bias_vec); + Load(smooth_weight + shift_smooth_offset, + &smooth_weight_vec); + } + +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + StoreFunc()(st.o, + shift_bias_vec, + smooth_weight_vec, + out_vec, + quant_max_bound, + quant_min_bound, + in_scale, + i); + } + Store( + out_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); + } + } + __syncthreads(); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +} diff --git a/custom_ops/gpu_ops/append_attention/config_for_attention.cu b/custom_ops/gpu_ops/append_attention/config_for_attention.cu new file mode 100644 index 00000000000..bacb519765e --- /dev/null +++ b/custom_ops/gpu_ops/append_attention/config_for_attention.cu @@ -0,0 +1,359 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "cute/tensor.hpp" +#include "helper.h" +#include "paddle/extension.h" +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU +#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h" +#include "paddle/phi/core/memory/memcpy.h" +#endif +#include "utils.cuh" + +template +__global__ void GetMaxLenKernel(const int *seq_lens_decoder, + const int *seq_lens_this_time, + const int *seq_lens_encoder, + int *max_lens, + const int batch_size) { + const int tid = threadIdx.x; + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int max_len_this_time_this_thread = 0; + int max_len_encoder_this_thread = 0; + int max_len_decoder_this_thread = 0; + int max_len_this_thread = 0; + int max_just_dec_len_this_thread = 0; + int max_len_kv_this_thread = 0; + for (int i = tid; i < batch_size; i += blockDim.x) { + const int seq_len_this_time = seq_lens_this_time[i]; + const int seq_len_decoder = seq_lens_decoder[i]; + max_len_this_time_this_thread = + max(seq_len_this_time, max_len_this_time_this_thread); + max_len_encoder_this_thread = + max(seq_lens_encoder[i], max_len_encoder_this_thread); + max_len_decoder_this_thread = + max(seq_len_decoder, max_len_decoder_this_thread); + if (seq_len_this_time <= 0) continue; + const int max_just_dec_len_now = + seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder; + max_len_this_thread = + max(seq_len_decoder + seq_len_this_time, max_len_this_thread); + max_just_dec_len_this_thread = + max(max_just_dec_len_this_thread, max_just_dec_len_now); + + if (seq_len_decoder == 0) continue; + max_len_kv_this_thread = + max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread); + } + int total_max_len_this_time = + BlockReduce(temp_storage) + .Reduce(max_len_this_time_this_thread, MaxOp()); + int total_max_len_encoder = + BlockReduce(temp_storage) + .Reduce(max_len_encoder_this_thread, MaxOp()); + int total_max_len_decoder = + BlockReduce(temp_storage) + .Reduce(max_len_decoder_this_thread, MaxOp()); + int total = + BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp()); + int total_just_dec = BlockReduce(temp_storage) + .Reduce(max_just_dec_len_this_thread, MaxOp()); + int total_max_len_kv = + BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp()); + if (tid == 0) { + max_lens[0] = total_max_len_this_time; + max_lens[1] = total_max_len_encoder; + max_lens[2] = total_max_len_decoder; + max_lens[3] = total; + max_lens[4] = total_just_dec; + max_lens[5] = total_max_len_kv; + } +} + +template +__global__ void config_decode_attn(const int *__restrict__ seq_lens_this_time, + const int *__restrict__ seq_lens_encoder, + const int *__restrict__ seq_lens_decoder, + int *__restrict__ block_indices, + int *__restrict__ num_blocks, + int *__restrict__ chunk_size, + const int bsz, + const int group_size, + const int kv_num_heads, + const int q_tile_size, + const int max_tokens_per_batch, + const int config_gridx) { + // one block one warp + const int tid = threadIdx.x, wid = threadIdx.y; + const uint32_t warp_size = blockDim.x; + __shared__ int num_block_all_shared[block_size]; + + const int 
lane_id = tid + wid * warp_size; + int cur_chunk_size = min_chunk_size * (lane_id + 1); + + // calculate num_block_all + int num_block_all = 0; + for (int bid = 0; bid < bsz; bid++) { + if (seq_lens_this_time[bid] <= 0 || seq_lens_encoder[bid] > 0) { + continue; + } + + int token_num_cur_batch = seq_lens_this_time[bid]; + int kv_len_cur_batch = seq_lens_decoder[bid] + token_num_cur_batch; + int q_tile_num = div_up(token_num_cur_batch * group_size, q_tile_size); + int kv_chunk_num = div_up(kv_len_cur_batch, cur_chunk_size); + num_block_all += q_tile_num * kv_chunk_num * kv_num_heads; + } + num_block_all_shared[lane_id] = num_block_all; + __syncthreads(); + + // search optimal chunk_size + int chunk_size_best; + int num_block_all_best; + if (tid == 0 && wid == 0) { + if (num_block_all_shared[0] <= config_gridx) { + chunk_size_best = min_chunk_size; + num_block_all_best = num_block_all_shared[0]; + } else if (num_block_all_shared[block_size - 1] >= config_gridx) { + chunk_size_best = min_chunk_size * block_size; + num_block_all_best = num_block_all_shared[block_size - 1]; + for (int i = block_size - 1; i >= 0; i--) { + if (num_block_all_shared[i] > num_block_all_best) { + break; + } + chunk_size_best = min_chunk_size * (i + 1); + } + } else { + chunk_size_best = min_chunk_size; + num_block_all_best = num_block_all_shared[0]; + for (int i = block_size - 1; i >= 0; i--) { + if (num_block_all_shared[i] > config_gridx) { + break; + } + if (num_block_all_shared[i] > num_block_all_best) { + num_block_all_best = num_block_all_shared[i]; + chunk_size_best = min_chunk_size * (i + 1); + } + } + } + num_blocks[0] = num_block_all_best; + chunk_size[0] = chunk_size_best; + } + + __syncthreads(); + if (wid == 0) { + chunk_size_best = __shfl_sync(0xffffffff, chunk_size_best, 0); + + // one block one warp + int prev_offset = 0; + // loop on warp tile:[base, base+32) + for (int base = 0; base < bsz; base += warp_size) { + const int bid = base + tid; + int q_tile_num = 0; + int kv_chunk_num = 0; + + // calculate loop_times for bid + int num_block_all = 0; + if (bid < bsz) { + int token_num_cur_batch = seq_lens_this_time[bid]; + if (seq_lens_encoder && seq_lens_encoder[bid] > 0) { + token_num_cur_batch = 0; + } + int kv_len_cur_batch = seq_lens_decoder[bid] + token_num_cur_batch; + q_tile_num = div_up(token_num_cur_batch * group_size, q_tile_size); + kv_chunk_num = div_up(kv_len_cur_batch, chunk_size_best); + num_block_all += q_tile_num * kv_chunk_num * kv_num_heads; + } + + // prefix sum for each lane, get the start offset in this tile + // inclusive scan + int x = num_block_all; + for (int offset = 1; offset < warp_size; offset <<= 1) { + int y = __shfl_up_sync(0xffffffff, x, offset); + if (tid >= offset) x += y; + } + // exclusive prefix sum + int bid_offset = x - num_block_all; + int tile_sum = __shfl_sync(0xffffffff, x, warp_size - 1); + + // write batch_ids and tile_ids_per_batch + if (bid < bsz && num_block_all > 0) { + int write_base = prev_offset + bid_offset; + for (int kv_head_id = 0; kv_head_id < kv_num_heads; kv_head_id++) { + for (int kv_chunk_id = 0; kv_chunk_id < kv_chunk_num; kv_chunk_id++) { + for (int q_tile_id = 0; q_tile_id < q_tile_num; q_tile_id++) { + int idx = + write_base * 4 + + ((kv_head_id * kv_chunk_num + kv_chunk_id) * q_tile_num + + q_tile_id) * + 4; + block_indices[idx] = bid; + block_indices[idx + 1] = kv_head_id; + block_indices[idx + 2] = kv_chunk_id; + block_indices[idx + 3] = q_tile_id; + } + } + } + } + // for next warp tile + prev_offset += tile_sum; + } + } +} + +void 
ConfigForAttention( + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &seq_lens_this_time, + paddle::Tensor &block_indices, // Inplace, shape:[block_num,4], block's + // indices with 4 dimension[batch_idx, + // kv_head_idx, kv_chunk_idx, q_tile_idx] + paddle::Tensor &num_blocks, // Inplace + paddle::Tensor &chunk_size, // Inplace + paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch) { + auto stream = seq_lens_encoder.stream(); + int bsz = seq_lens_this_time.shape()[0]; + + paddle::Tensor max_len_tensor_gpu = + GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, + paddle::DataType::INT32, + seq_lens_this_time.place()); + + GetMaxLenKernel<1024><<<1, 1024, 0, stream>>>(seq_lens_decoder.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + max_len_tensor_gpu.data(), + bsz); + // Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU data + // is only for branching in attention. +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if (!phi::backends::gpu::IsCUDAGraphCapturing()) +#endif + max_len_tensor_cpu.copy_( + max_len_tensor_gpu, max_len_tensor_cpu.place(), false); + auto max_len_cpu_ptr = max_len_tensor_cpu.data(); + int max_just_dec_len_this_time = max_len_cpu_ptr[4]; + + const uint32_t block_indices_ele_num = block_indices.size(); + + // decoder + if (max_just_dec_len_this_time > 0) { + CUDA_CHECK(cudaMemsetAsync(block_indices.data(), + 0, + block_indices_ele_num * sizeof(int32_t), + stream)); + CUDA_CHECK( + cudaMemsetAsync(num_blocks.data(), 0, sizeof(int32_t), stream)); + CUDA_CHECK( + cudaMemsetAsync(chunk_size.data(), 0, sizeof(int32_t), stream)); + + int device; + CUDA_CHECK(cudaGetDevice(&device)); + int sm_cout; + CUDA_CHECK(cudaDeviceGetAttribute( + &sm_cout, cudaDevAttrMultiProcessorCount, device)); + const int config_gridx = sm_cout * 2; + + // 选择最优的q_tile_size + int q_tile_size = 32; + if (group_size * max_tokens_per_batch <= 16) { + q_tile_size = 16; + } + dim3 blocks(32, 4); + if (cache_quant_type == "cache_int4_zp") { + config_decode_attn<256, 128> + <<<1, blocks, 0, stream>>>(seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + block_indices.data(), + num_blocks.data(), + chunk_size.data(), + bsz, + group_size, + kv_num_heads, + q_tile_size, + max_tokens_per_batch, + config_gridx); + } else { + config_decode_attn<128, 128> + <<<1, blocks, 0, stream>>>(seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + block_indices.data(), + num_blocks.data(), + chunk_size.data(), + bsz, + group_size, + kv_num_heads, + q_tile_size, + max_tokens_per_batch, + config_gridx); + } + } +} + +std::vector> ConfigForAttentionInferShape( + const std::vector &seq_lens_encoder_shape, + const std::vector &seq_lens_decoder_shape, + const std::vector &seq_lens_this_time_shape, + const std::vector &num_blocks_shape, + const std::vector &chunk_size_shape, + const std::vector &max_len_tensor_cpu_shape, + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch) { + return {}; +} + +std::vector ConfigForAttentionInferDtype( + const paddle::DataType &seq_lens_encoder_dtype, + const paddle::DataType &seq_lens_decoder_dtype, + const paddle::DataType &seq_lens_this_time_dtype, + const paddle::DataType &num_blocks_dtype, + const paddle::DataType &chunk_size_dtype, + const paddle::DataType 
&max_len_tensor_cpu_dtype, + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch) { + return {}; +} + +PD_BUILD_STATIC_OP(config_for_attention) + .Inputs({ + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "block_indices", + "num_blocks", + "chunk_size", + "max_len_tensor_cpu", + }) + .Outputs({ + + }) + .Attrs({"cache_quant_type: std::string", + "group_size: int", + "kv_num_heads: int", + "max_tokens_per_batch: int"}) + .SetKernelFn(PD_KERNEL(ConfigForAttention)) + .SetInferShapeFn(PD_INFER_SHAPE(ConfigForAttentionInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ConfigForAttentionInferDtype)); diff --git a/custom_ops/gpu_ops/append_attention/cu_tensor_map.cuh b/custom_ops/gpu_ops/append_attention/cu_tensor_map.cuh new file mode 100644 index 00000000000..ff84e1cd3f6 --- /dev/null +++ b/custom_ops/gpu_ops/append_attention/cu_tensor_map.cuh @@ -0,0 +1,124 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include +#include +#include +#include +#include + +using barrier = cuda::barrier; +namespace cde = cuda::device::experimental; + +template +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; +}; + +template +CUtensorMap makeTensorMapForKVCache(T const* addr, + uint32_t block_num, + uint32_t kv_num_head, + uint32_t second_size, + uint32_t last_size) { + CUtensorMap tensorMap{}; + + uint32_t elem_bytes = sizeof(T); + + uint32_t const last_size_bytes = elem_bytes * last_size; + // VLLM Layout + CUtensorMapDataType data_dtype = cu_tensor_map_type_traits::type; + constexpr uint32_t rank = 4; + uint64_t global_dims[] = {last_size, second_size, kv_num_head, block_num}; + uint64_t global_strides[] = {last_size_bytes, + second_size * last_size_bytes, + kv_num_head * second_size * last_size_bytes}; + + uint32_t box_dims[] = {last_size, second_size, 1, 1}; + uint32_t elem_strides[] = {1, 1, 1, 1}; + + auto const swizzle = [&] { + switch (last_size_bytes) { + case 128: + return CU_TENSOR_MAP_SWIZZLE_128B; + case 64: + return CU_TENSOR_MAP_SWIZZLE_64B; + default: + throw std::runtime_error("unsupported cache last_size"); + } + }(); + CUresult res = cuTensorMapEncodeTiled( + &tensorMap, + data_dtype, + rank, + 
reinterpret_cast(const_cast(addr)), + global_dims, + global_strides, + box_dims, + elem_strides, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + switch (res) { + case CUDA_SUCCESS: + printf("CUDA_SUCCESS!\n"); + break; + case CUDA_ERROR_INVALID_VALUE: + printf("CUDA_ERROR_INVALID_VALUE\n"); + break; + case CUDA_ERROR_OUT_OF_MEMORY: + printf("CUDA_ERROR_OUT_OF_MEMORY\n"); + break; + case CUDA_ERROR_NOT_INITIALIZED: + printf("CUDA_ERROR_NOT_INITIALIZED\n"); + break; + case CUDA_ERROR_DEINITIALIZED: + printf("CUDA_ERROR_DEINITIALIZED\n"); + break; + case CUDA_ERROR_PROFILER_DISABLED: + printf("CUDA_ERROR_PROFILER_DISABLED\n"); + break; + default: + throw std::runtime_error("unsupported res!"); + } + + return tensorMap; +} diff --git a/custom_ops/gpu_ops/append_attention/decode_append_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attention/decode_append_attention_c8_impl.cuh new file mode 100644 index 00000000000..c0695927e04 --- /dev/null +++ b/custom_ops/gpu_ops/append_attention/decode_append_attention_c8_impl.cuh @@ -0,0 +1,849 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
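+// Decode-phase append-attention implementation for the 8-bit (c8) KV-cache path,
+// built on the fragment-level softmax/MMA helpers defined in attention_func.cuh.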
+#pragma once +#include "utils.cuh" +// #include "cu_tensor_map.cuh" +#include "attention_func.cuh" + +template +void print_params(AttentionParams const params) { + printf("max_model_len: %d\n", params.max_model_len); + printf("max_kv_len: %d\n", params.max_kv_len); + printf("max_blocks_per_seq: %d\n", params.max_blocks_per_seq); + printf("softmax_scale: %f\n", params.softmax_scale); + printf("quant_max_bound: %f\n", params.quant_max_bound); + printf("quant_min_bound: %f\n", params.quant_min_bound); + printf("max_tokens_per_batch: %d\n", params.max_tokens_per_batch); + printf("attn_mask_len: %d\n", params.attn_mask_len); + printf("sliding_window: %d\n", params.sliding_window); + printf("q_num_heads: %d\n", params.q_num_heads); + printf("kv_num_heads: %d\n", params.kv_num_heads); + printf("max_num_chunks: %d\n", params.max_num_chunks); + printf("max_tile_q: %d\n", params.max_tile_q); + printf("batch_size: %d\n", params.batch_size); +} + +// __launch_bounds__( +// NUM_THREADS_PER_BLOCK, 1 +// ) +template +__global__ void decode_append_attention_c8_kernel( + const __grid_constant__ AttentionParams params + // const __grid_constant__ CUtensorMap key_tensor_map, + // const __grid_constant__ CUtensorMap value_tensor_map +) { + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + + // 内存分配 + extern __shared__ __align__(128) uint8_t smem[]; + smem_t qo_smem(smem); + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); + smem_t k_scale_smem; + smem_t v_scale_smem; + T *k_smem_scale_ptr = nullptr; + T *v_smem_scale_ptr = nullptr; + + // TMA + // #pragma nv_diag_suppress static_var_with_dynamic_init + // __shared__ __align__(128) barrier bar[4]; + // if(tid == 0 && wid == 0) { + // for (int i = 0; i < 4; ++i) { + // init(&(bar[i]), blockDim.x * blockDim.y); + // cde::fence_proxy_async_shared_cta(); + // } + // } + // __syncthreads(); + + int total_block = params.num_blocks_ptr[0]; + int chunk_size = params.chunk_size_ptr[0]; + + for (int lane_idx = blockIdx.x; lane_idx < total_block; + lane_idx += gridDim.x) { + // block_indices: shape [block_num,4], block's indices with 4 + // dimension[batch_idx, kv_head_idx, kv_chunk_idx, q_tile_idx] + int batch_idx = params.block_indices[lane_idx * 4]; + int kv_head_idx = params.block_indices[lane_idx * 4 + 1]; + int chunk_idx = params.block_indices[lane_idx * 4 + 2]; + int tile_idx = params.block_indices[lane_idx * 4 + 3]; + int q_head_idx = kv_head_idx * GROUP_SIZE; + + const uint32_t q_len = params.seq_lens_q[batch_idx]; + if (q_len <= 0) { + continue; + } + const int *block_table_now = + params.block_table + batch_idx * params.max_blocks_per_seq; + + T cache_k_scale_reg[IsDynamicC8 + ? num_frags_z * 2 + : (is_scale_channel_wise ? num_frags_y * 4 : 1)]; + T cache_v_scale_reg[IsDynamicC8 + ? num_frags_z * 4 + : (is_scale_channel_wise ? 
num_frags_y * 2 : 1)]; + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; + const T *cache_k_scale_cur_head = params.cache_k_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; + cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; + cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; + cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; + } + scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_cur_head = params.cache_v_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; + cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; + } + } else { + cache_k_scale_reg[0] = params.cache_k_scale[kv_head_idx]; + cache_v_scale_reg[0] = params.cache_v_scale[kv_head_idx]; + } + } + const uint32_t num_rows_per_block = num_frags_x * 16; + const uint32_t q_end = + min(q_len, div_up((tile_idx + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = params.seq_lens_kv[batch_idx]; + + if (kv_len <= 0) { + continue; + } + kv_len += q_len; + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + continue; + } + + // 相关const变量 + // barrier::arrival_token tokens[4]; + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + + const uint32_t q_n_stride = params.q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = + (params.q_num_heads + params.kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = params.kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + const uint32_t kv_d_stride = BLOCK_SIZE; + + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + + T *o_base_ptr_T = nullptr; + + const uint32_t chunk_start = chunk_idx * chunk_size; + const uint32_t chunk_end = min(kv_len, chunk_start + chunk_size); + const uint32_t chunk_len = chunk_end - chunk_start; + + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_start_seq_id = params.cu_seqlens_q[batch_idx]; + const uint32_t q_base_seq_id_this_block = tile_idx * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = params.qkv + q_offset; + + o_base_ptr_T = params.tmp_o + + batch_idx * params.max_tokens_per_batch * + params.max_num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const int *mask_offset_this_seq = + params.mask_offset ? 
params.mask_offset + q_start_seq_id * 2 : nullptr; + + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + tid % 16, tid / 16); // 16 * 16 + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + // if(blockIdx.x == 0 && tid == 0 && wid == 0) { + // printf("load q end!\n"); + // } + // __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, params.softmax_scale); + + if constexpr (IsDynamicC8) { + k_smem_scale_ptr = reinterpret_cast( + smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); + v_smem_scale_ptr = k_smem_scale_ptr + NUM_WARP_KV * num_frags_z * 16; + k_scale_smem.base = reinterpret_cast(k_smem_scale_ptr); + v_scale_smem.base = reinterpret_cast(v_smem_scale_ptr); + } + + const uint32_t num_iterations = + div_up(CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_idx + 1) * num_rows_per_block, + GROUP_SIZE), + chunk_start))) + : chunk_len, + NUM_WARP_KV * num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_idx * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : params.mask_offset ? 0 + : chunk_len) / + (NUM_WARP_KV * num_frags_z * 16); + + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + (wid / 2) * num_frags_y * 16 + 8 * (tid / 16) + tid % 8, + (wid % 2) * num_frags_z + (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); + + uint32_t kv_idx_base = chunk_start; + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_d_stride + + tid % 4 * num_elems_per_128b(); + + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + params.cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + // #pragma unroll 1 + // for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + // int block_id = __ldg(&block_table_now[(kv_idx_base + kv_i * 64) / + // BLOCK_SIZE]); if (block_id < 0) block_id = 0; if (tid == 0 && wid + // == 0) { + // // 发起 TMA 四维异步拷贝操作 + // cde::cp_async_bulk_tensor_4d_global_to_shared((void*)(smem + + // num_frags_x * 16 * HEAD_DIM * sizeof(T) + kv_i * (NUM_WARP_KV * + // 16 * HEAD_DIM * sizeof(CacheT))), &key_tensor_map, 0, 0, + // kv_head_idx, block_id, bar[kv_i]); + // // 设置同步等待点,指定需要等待的拷贝完成的字节数。 + // tokens[kv_i] = cuda::device::barrier_arrive_tx(bar[kv_i], 1, + // NUM_WARP_KV * 16 * HEAD_DIM * sizeof(CacheT)); + // // printf("t0 barrier_arrive_tx end\n"); + // } else { + // // Other threads just arrive. 
+ // tokens[kv_i] = bar[kv_i].arrive(); + // // printf("t1 arrive end token:%d\n", token); + // } + // } + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + params.cache_k_scale, + kv_idx_base, + params.kv_num_heads, + kv_head_idx, + chunk_end); + // commit_group(); + } + commit_group(); + + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + params.cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + // #pragma unroll 1 + // for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + // int block_id = __ldg(&block_table_now[(kv_idx_base + kv_i * 64) / + // BLOCK_SIZE]); if (block_id < 0) block_id = 0; if (tid == 0 && wid + // == 0) { + // // 发起 TMA 四维异步拷贝操作 + // cde::cp_async_bulk_tensor_4d_global_to_shared(smem + num_frags_x + // * 16 * HEAD_DIM * sizeof(T) + + // NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) + + // kv_i * (NUM_WARP_KV * 16 * HEAD_DIM * sizeof(CacheT)), + // &value_tensor_map, 0, 0, kv_head_idx, block_id, bar[2 + + // kv_i]); + // // 设置同步等待点,指定需要等待的拷贝完成的字节数。 + // // printf("bit:%d", NUM_WARP_KV * 16 * HEAD_DIM * + // sizeof(CacheT)); tokens[2 + kv_i] = + // cuda::device::barrier_arrive_tx(bar[2 + kv_i], 1, NUM_WARP_KV * + // 16 * HEAD_DIM * sizeof(CacheT)); + // } else { + // // Other threads just arrive. + // tokens[2 + kv_i] = bar[2 + kv_i].arrive(); + // } + // } + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(v_scale_smem, + block_table_now, + params.cache_v_scale, + kv_idx_base, + params.kv_num_heads, + kv_head_idx, + chunk_end); + // commit_group(); + } + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + + if constexpr (IsDynamicC8) { + produce_k_dynamic_scale_smem2reg(k_smem_scale_ptr, + cache_k_scale_reg); + } + + // s = qk + // #pragma unroll 1 + // for(uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + // bar[kv_i].wait(std::move(tokens[kv_i])); + // } + compute_qk_c8(&qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + cache_k_scale_reg, + s_frag); + + if (iter >= mask_check_iteration || params.sliding_window > 0) { + mask_s(params.attn_mask + ? 
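+               // pick this batch's dense attention mask (an attn_mask_len x attn_mask_len square) when one was provided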
params.attn_mask + batch_idx * + params.attn_mask_len * + params.attn_mask_len + : nullptr, + q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + params.attn_mask_len, + s_frag, + mask_offset_this_seq, + params.sliding_window); + } + + // update m,d + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + // const uint32_t ori_kv_idx_base = kv_idx_base; + kv_idx_base += NUM_WARP_KV * num_frags_z * 16; + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + params.cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + // if (iter < num_iterations - 1) { + // #pragma unroll 1 + // for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + // int block_id = __ldg(&block_table_now[(kv_idx_base + kv_i * + // 64) / BLOCK_SIZE]); if (block_id < 0) block_id = 0; if (tid + // == 0 && wid == 0) { + // // 发起 TMA 四维异步拷贝操作 + // cde::cp_async_bulk_tensor_4d_global_to_shared(smem + + // num_frags_x * 16 * HEAD_DIM * sizeof(T) + kv_i * + // (NUM_WARP_KV * 16 * HEAD_DIM * sizeof(CacheT)), + // &key_tensor_map, 0, 0, kv_head_idx, block_id, bar[kv_i]); + // // 设置同步等待点,指定需要等待的拷贝完成的字节数。 + // tokens[kv_i] = cuda::device::barrier_arrive_tx(bar[kv_i], + // 1, NUM_WARP_KV * 16 * HEAD_DIM * sizeof(CacheT)); + // } else { + // // Other threads just arrive. + // tokens[kv_i] = bar[kv_i].arrive(); + // } + // } + // } + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async( + k_scale_smem, + block_table_now, + params.cache_k_scale, + kv_idx_base, + params.kv_num_heads, + kv_head_idx, + chunk_end); + // commit_group(); + } + commit_group(); + wait_group<1>(); + __syncthreads(); + + if constexpr (IsDynamicC8) { + produce_v_dynamic_scale_smem2reg(v_smem_scale_ptr, + cache_v_scale_reg); + } + + // #pragma unroll 1 + // for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + // bar[2 + kv_i].wait(std::move(tokens[2 + kv_i])); + // } + // compute sfm * v + compute_sfm_v_c8_iter_sq_bvec( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); + __syncthreads(); + + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + params.cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + // if (iter < num_iterations - 1) { + // #pragma unroll 1 + // for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + // int block_id = __ldg(&block_table_now[(kv_idx_base + kv_i * + // 64) / BLOCK_SIZE]); if (block_id < 0) block_id = 0; if (tid + // == 0 && wid == 0) { + // // 发起 TMA 四维异步拷贝操作 + // cde::cp_async_bulk_tensor_4d_global_to_shared(smem + + // num_frags_x * 16 * HEAD_DIM * sizeof(T) + + // NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * + // sizeof(CacheT) + kv_i * (NUM_WARP_KV * 16 * HEAD_DIM * + // sizeof(CacheT)), &value_tensor_map, 0, 0, kv_head_idx, + // block_id, bar[2 + kv_i]); + // // 设置同步等待点,指定需要等待的拷贝完成的字节数。 + // tokens[2 + kv_i] = cuda::device::barrier_arrive_tx(bar[2 + + // kv_i], 1, NUM_WARP_KV * 16 * HEAD_DIM * sizeof(CacheT)); + // } else { + // // Other threads just arrive. 
+ // tokens[2 + kv_i] = bar[2 + kv_i].arrive(); + // } + // } + // } + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async( + v_scale_smem, + block_table_now, + params.cache_v_scale, + kv_idx_base, + params.kv_num_heads, + kv_head_idx, + chunk_end); + // commit_group(); + } + commit_group(); + } + wait_group<0>(); + __syncthreads(); + // #pragma unroll 1 + // for (uint32_t i = 0; i < NUM_WARP_KV; ++i) { + // bar[i].wait(std::move(tokens[i])); + // } + merge_block_res( + o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); + + if (num_chunks_this_seq <= 1) { + normalize_d(o_frag, d_frag); + } + // write o + // [num_frags_x, 16, num_frags_y, 16] + write_o_reg_gmem_multi_warps( + o_frag, + &qo_smem, + o_base_ptr_T, + q_base_seq_id_this_block, + q_head_idx, + q_len, + q_n_stride * params.max_num_chunks, + HEAD_DIM); + + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + offset = ((batch_idx * params.max_tokens_per_batch + + qo_idx_now / GROUP_SIZE) * + params.max_num_chunks + + chunk_idx) * + params.q_num_heads + + qo_head_idx; + params.tmp_m[offset] = m_frag[fx][j]; + params.tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } +} + +template +void DecodeAppendC8Attention(const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::Tensor &tmp_workspace, + const paddle::Tensor &tmp_m, + const paddle::Tensor &tmp_d, + const paddle::optional &attn_mask, + const paddle::Tensor &cache_k_scale, + const paddle::Tensor &cache_v_scale, + const paddle::optional &sinks, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &block_indices, + const paddle::Tensor &num_blocks, + const paddle::Tensor &chunk_size, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + cudaStream_t &stream, + paddle::Tensor *out, + const int sliding_window) { + using NV_TYPE = typename type_traits::nv_type; + + auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_num; + auto bsz = meta_data.batch_size; + auto max_blocks_per_seq = meta_data.max_blocks_per_seq; + + constexpr uint32_t NUM_WARP_Q = 1; + constexpr uint32_t NUM_WARP_KV = NUM_WARPS_PER_BLOCK / NUM_WARP_Q; + constexpr uint32_t num_frags_x = Q_TILE_SIZE / (16 * NUM_WARP_Q); + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; + + auto *allocator = paddle::GetAllocator(qkv.place()); + + bool is_scale_channel_wise = false; + if (cache_k_scale.dims()[0] == HEAD_DIM * kv_num_heads) { + is_scale_channel_wise = true; + } + + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2; + constexpr uint32_t smem_size_0 = + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 + + NUM_WARP_KV * num_frags_z * 16 * 
sizeof(T) * 2; + constexpr uint32_t smem_size_1 = + NUM_WARPS_PER_BLOCK * num_frags_x * num_frags_y * 32 * 8 * sizeof(float) + + NUM_WARPS_PER_BLOCK * num_frags_x * 2 * 32 * 8; + constexpr uint32_t smem_size = + smem_size_0 > smem_size_1 ? smem_size_0 : smem_size_1; + + auto split_kv_kernel = decode_append_attention_c8_kernel; + if (is_scale_channel_wise) { + split_kv_kernel = decode_append_attention_c8_kernel; + } + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + // uint32_t chunk_size = static_cast(max_partition_size); + + const int max_num_chunks = div_up(max_seq_len, 128); + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + + // phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + // tmp_workspace = allocator->Allocate( + // phi::SizeOf(qkv.dtype()) * + // static_cast(max_tokens_per_batch * bsz * + // max_num_chunks * num_heads * HEAD_DIM)); + // tmp_m = allocator->Allocate( + // phi::SizeOf(paddle::DataType::FLOAT32) * + // static_cast(max_tokens_per_batch * bsz * + // max_num_chunks * num_heads)); + // tmp_d = allocator->Allocate( + // phi::SizeOf(paddle::DataType::FLOAT32) * + // static_cast(max_tokens_per_batch * bsz * + // max_num_chunks * num_heads)); + // } + // } + AttentionParams params; + memset(¶ms, 0, sizeof(AttentionParams)); + + params.qkv = reinterpret_cast(const_cast(qkv.data())); + params.cache_k = const_cast(cache_k.data()); + params.cache_v = const_cast(cache_v.data()); + params.cache_k_scale = + reinterpret_cast(const_cast(cache_k_scale.data())); + params.cache_v_scale = + reinterpret_cast(const_cast(cache_v_scale.data())); + params.seq_lens_q = const_cast(seq_lens_q.data()); + params.seq_lens_kv = const_cast(seq_lens_kv.data()); + params.block_indices = const_cast(block_indices.data()); + params.num_blocks_ptr = const_cast(num_blocks.data()); + params.chunk_size_ptr = const_cast(chunk_size.data()); + params.cu_seqlens_q = const_cast(cu_seqlens_q.data()); + params.block_table = const_cast(block_table.data()); + params.mask_offset = const_cast(meta_data.mask_offset); + params.attn_mask = + attn_mask ? const_cast(attn_mask.get().data()) : nullptr; + params.max_model_len = max_dec_len; + params.max_kv_len = max_dec_len; + params.max_blocks_per_seq = max_blocks_per_seq; + params.softmax_scale = 1.f / sqrt(HEAD_DIM); + params.quant_max_bound = quant_max_bound; + params.quant_min_bound = quant_min_bound; + params.tmp_o = + reinterpret_cast(const_cast(tmp_workspace.data())); + params.tmp_m = const_cast(tmp_m.data()); + params.tmp_d = const_cast(tmp_d.data()); + params.max_tokens_per_batch = max_tokens_per_batch; + params.attn_mask_len = + attn_mask ? 
attn_mask_len = attn_mask.get().shape()[1] : -1; + params.sliding_window = sliding_window; + params.q_num_heads = num_heads; + params.kv_num_heads = kv_num_heads; + params.max_num_chunks = max_num_chunks; + // params.max_tile_q = div_up(GROUP_SIZE * max_tokens_per_batch, + // BLOCK_SHAPE_Q); + params.batch_size = meta_data.batch_size; + // params.num_blocks_x = num_blocks_x_cpu; + + int device; + CUDA_CHECK(cudaGetDevice(&device)); + int sm_cout; + CUDA_CHECK( + cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device)); + + dim3 grids( + sm_cout * + 2); // TODO(lizhenyun): tuning optimal gridx to while num_frags_x == 2 + dim3 blocks(32, NUM_WARPS_PER_BLOCK); + + // auto cache_k_dim = cache_k.dims(); + // CUtensorMap key_tensor_map = + // makeTensorMapForKVCache(cache_k.data(), + // cache_k.dims()[0], params.kv_num_heads, BLOCK_SIZE, HEAD_DIM); CUtensorMap + // value_tensor_map = + // makeTensorMapForKVCache(cache_v.data(), + // cache_v.dims()[0], params.kv_num_heads, HEAD_DIM, BLOCK_SIZE); + launchWithPdlWhenEnabled( + split_kv_kernel, grids, blocks, smem_size, stream, params); + constexpr int vec_size = num_elems_per_128b(); + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); + dim3 blocks_merge(blockx, blocky); + launchWithPdlWhenEnabled( + merge_chunks_kernel, + grids_merge, + blocks_merge, + 0, + stream, + params.tmp_o, + params.tmp_m, + params.tmp_d, + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + (NV_TYPE *)nullptr, + (NV_TYPE *)nullptr, + sinks + ? reinterpret_cast(const_cast(sinks.get().data())) + : nullptr, + chunk_size.data(), + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + -1, + max_seq_len, + max_num_chunks, + num_heads, + HEAD_DIM, + token_num, + max_tokens_per_batch); +} diff --git a/custom_ops/gpu_ops/append_attention/mem_util.cuh b/custom_ops/gpu_ops/append_attention/mem_util.cuh new file mode 100644 index 00000000000..18788858923 --- /dev/null +++ b/custom_ops/gpu_ops/append_attention/mem_util.cuh @@ -0,0 +1,389 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
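+// Shared-memory utilities for the append-attention kernels: cp.async copy wrappers (with optional zero-fill and L2 prefetch), ldmatrix loads, and the permuted shared-memory layout exposed through smem_t.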
+#pragma once + +#include +#include +#include + +enum class SharedMemFillMode { kFillZero, kNoFill }; + +enum class PrefetchMode { kNoPrefetch, kPrefetch }; + +template +__device__ __forceinline__ void ldmatrix_m8n8x4_impl(uint32_t* R, T* smem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +} + +template +__device__ __forceinline__ void ldmatrix_m8n8x4_trans_impl(uint32_t* R, + T* smem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +} + +__device__ __forceinline__ void commit_group() { +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + {} +#else + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +template +__device__ __forceinline__ void wait_group() { +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + cooperative_groups::wait(cooperative_groups::this_thread_block()); +#else + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +#endif +} + +template +__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } else { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } +#else + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"( + smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + "r"(16)); + } else { + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + "r"(16)); + } +#endif +} + +template +__device__ __forceinline__ void pred_load_128b(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 16 : 0; + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), + (void*)gmem_ptr, + src_in_bytes); + } else { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), + (void*)gmem_ptr, + src_in_bytes); + } + } else { + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } + } + } +#else + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 
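+      // cp.async copies src_in_bytes (16 when the predicate holds, otherwise 0) and zero-fills the rest of the 16-byte destination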
16 : 0; + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"( + smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + "r"(src_in_bytes)); + } else { + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + "r"(src_in_bytes)); + } + } else { + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16)); + } + } +#endif +} + +template +__device__ __forceinline__ void pred_load_64b(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 8 : 0; + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 8); + memcpy( + __cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, src_in_bytes); + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 8); + } + } +#else + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 8 : 0; + asm volatile( + "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(8), + "r"(src_in_bytes)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(8)); + } +#endif +} + +template +__device__ __forceinline__ void pred_load_32b(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 4 : 0; + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 4); + memcpy( + __cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, src_in_bytes); + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 4); + } + } +#else + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 
4 : 0; + asm volatile( + "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(4), + "r"(src_in_bytes)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(4)); + } +#endif +} + +template +__device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) { + static_assert(num_bits == 128, "num_bits must be 128"); + load_128b(smem_ptr, gmem_ptr); +} + +template +__device__ __forceinline__ void pred_load(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + static_assert(num_bits == 128 || num_bits == 64 || num_bits == 32, + "num_bits must be 128, 64 or 32."); + if constexpr (num_bits == 128) { + pred_load_128b(smem_ptr, gmem_ptr, predicate); + } else if constexpr (num_bits == 64) { + pred_load_64b(smem_ptr, gmem_ptr, predicate); + } else if constexpr (num_bits == 32) { + pred_load_32b(smem_ptr, gmem_ptr, predicate); + } +} + +using b32_t = uint32_t; +using b64_t = uint2; +using b128_t = uint4; + +template +constexpr __host__ __device__ __forceinline__ uint32_t num_elems_per_128b() { + return sizeof(b128_t) / sizeof(T); +} + +struct smem_t { + // The base pointer. + b128_t* base; + __device__ __forceinline__ smem_t() : base(nullptr) {} + template + __device__ __forceinline__ smem_t(T* base) : base((b128_t*)base) {} + + template + static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, + uint32_t j) { + if constexpr (inv_stride <= 1) { + return i * stride + (j ^ (i % 8)); + } else { + return i / inv_stride * 8 + ((j + (i % inv_stride) * stride)) ^ + ((i / inv_stride) % 8); + } + } + + template + static __device__ __forceinline__ uint32_t + advance_offset_by_column(uint32_t offset, uint32_t step_idx) { + if constexpr (row_stride == 2) { + static_assert(step_size == 2, "Unsupported step size"); + return offset + step_size; + } else if constexpr (row_stride == 4) { + static_assert(step_size == 2 || step_size == 4, "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ 0x2) + (step_idx % 2 == 1) * 4; + } else { + return offset + step_size; + } + } else { + static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0, + "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + + (step_idx % 4 == 3) * 8; + } else if constexpr (step_size == 4) { + return (offset ^ 0x4) + (step_idx % 2 == 1) * 8; + } else { + // step_size % 8 == 0 + return offset + step_size; + } + } + } + + template + static __device__ __forceinline__ uint32_t + advance_offset_by_row(uint32_t offset) { + if constexpr (row_stride == 2) { + static_assert(step_size == 16 || step_size % 32 == 0, + "Unsupported step size"); + if constexpr (step_size == 16) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 32 == 0 + return offset + step_size * row_stride; + } + } else if constexpr (row_stride == 4) { + static_assert(step_size == 8 || step_size % 16 == 0, + "Unsupported step size"); + if constexpr (step_size == 8) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 16 == 0 + return offset + step_size * row_stride; + } + } else { + static_assert(step_size == 4 || step_size % 8 == 0, + "Unsupported step size"); + if constexpr (step_size == 4) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 8 == 0 + return offset + step_size * row_stride; + } 
+ } + } + + __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t offset, + uint32_t* R) { + b128_t* smem_ptr = base + offset; + ldmatrix_m8n8x4_impl(R, smem_ptr); + } + + __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t offset, + uint32_t* R) { + b128_t* smem_ptr = base + offset; + ldmatrix_m8n8x4_trans_impl(R, smem_ptr); + } + + template + __device__ __forceinline__ void load_128b_async(uint32_t offset, + const T* gptr, + bool predicate) { + b128_t* smem_ptr = base + offset; + pred_load_128b( + smem_ptr, reinterpret_cast(gptr), predicate); + } + + template + __device__ __forceinline__ void load_128b_async(uint32_t offset, + const T* gptr) { + b128_t* smem_ptr = base + offset; + load_128b(smem_ptr, + reinterpret_cast(gptr)); + } + + template + __device__ __forceinline__ void store_128b(uint32_t offset, T* gptr) { + *reinterpret_cast(gptr) = *(base + offset); + } +}; diff --git a/custom_ops/gpu_ops/append_attention/mma_tensor_op.cuh b/custom_ops/gpu_ops/append_attention/mma_tensor_op.cuh new file mode 100644 index 00000000000..8662ee298d2 --- /dev/null +++ b/custom_ops/gpu_ops/append_attention/mma_tensor_op.cuh @@ -0,0 +1,296 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
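+// Tensor-core MMA wrappers: mma.sync m16n8k32 s8*s8->s32 and m16n8k16 f16/bf16->f32 instructions, plus an MMA-based row-sum helper.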
+#pragma once + +#include +#include +#include + +enum class MMAMode { + kInit = 0U, + kInplaceUpdate = 1U, +}; + +template +__device__ __forceinline__ void mma_sync_m16n16k32_row_col_i8i8i32( + int* C, // 8 + uint32_t* A, // 4 + uint32_t* B) { // 4 + if constexpr (mma_mode == MMAMode::kInit) { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "r"(0), + "r"(0), + "r"(0), + "r"(0)); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "r"(0), + "r"(0), + "r"(0), + "r"(0)); + } else { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "r"(C[0]), + "r"(C[1]), + "r"(C[2]), + "r"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "r"(C[4]), + "r"(C[5]), + "r"(C[6]), + "r"(C[7])); + } +} + +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32( + float* C, uint32_t* A, uint32_t* B) { + if constexpr (mma_mode == MMAMode::kInit) { + if constexpr (std::is_same::value) { // fp16 + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "f"(0.f), + "f"(0.f), + "f"(0.f), + "f"(0.f)); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "f"(0.f), + "f"(0.f), + "f"(0.f), + "f"(0.f)); + } else { // bf16 + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "f"(0.f), + "f"(0.f), + "f"(0.f), + "f"(0.f)); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "f"(0.f), + "f"(0.f), + "f"(0.f), + "f"(0.f)); + } + } else { + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + 
"r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7])); + } else { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7])); + } + } +} + +template +__device__ __forceinline__ void rowsum_f16f16f32(float* d, DType* s) { + static_assert(sizeof(DType) == 2, "DType must be 16bit floating data type"); + uint32_t* s_u32 = (uint32_t*)(s); + if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg .f32 ph;\n" + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, ph, %1, ph}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[0]), + "r"(s_u32[1]), + "r"(s_u32[2]), + "r"(s_u32[3]), + "r"(1006648320), + "r"(1006648320), + "f"(d[0]), + "f"(d[1])); + } else { + asm volatile( + "{\n" + ".reg .f32 ph;\n" + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, ph, %1, ph}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[0]), + "r"(s_u32[1]), + "r"(s_u32[2]), + "r"(s_u32[3]), + "r"(1065369472), + "r"(1065369472), + "f"(d[0]), + "f"(d[1])); + } +} diff --git a/custom_ops/gpu_ops/append_attention/template_config.json b/custom_ops/gpu_ops/append_attention/template_config.json new file mode 100644 index 00000000000..cb1b678155f --- /dev/null +++ b/custom_ops/gpu_ops/append_attention/template_config.json @@ -0,0 +1,44 @@ +{ + "multiquery_attention_c8": { + "name": "decode_append_attention_c8_kernel", + "function_name": "decode_append_attention_c8_kernel", + "impl_file": "decode_append_attention_c8_impl.cuh", + "template_params": [ + "T", + "CacheT", + "GROUP_SIZE", + "CAUSAL", + "NUM_WARPS", + "NUM_WARP_Q", + "NUM_WARP_KV", + "HEAD_DIM", + "BLOCK_SIZE", + "num_frags_x", + "num_frags_y", + "num_frags_z", + "is_scale_channel_wise", + "IsFP8", + "IsDynamicC8" + ], + "dispatch_params": { + "T": ["half", "__nv_bfloat16"], + "CacheT": ["uint8_t"], + "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16], + "CAUSAL": [0, 1], + "NUM_WARPS": [4], + "NUM_WARP_Q": [1], + "NUM_WARP_KV": [4], + "HEAD_DIM": [128], + "BLOCK_SIZE": [64], + "num_frags_x": [1, 2], + "num_frags_y": [8], + "num_frags_z": [1], + "is_scale_channel_wise": [0, 1], + "IsFP8": [0, 1], + "IsDynamicC8": [0, 1] + }, + "max_instances_per_file": 80, + "file_prefix": "decode_append_attention_c", + "function_signature": "template __global__ void {function_name}{template_args}(const __grid_constant__ AttentionParams{params_template_args} params);\n\n" + } +} diff --git a/custom_ops/gpu_ops/append_attention/utils.cuh 
b/custom_ops/gpu_ops/append_attention/utils.cuh new file mode 100644 index 00000000000..536867eee16 --- /dev/null +++ b/custom_ops/gpu_ops/append_attention/utils.cuh @@ -0,0 +1,710 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include +#include +#include +#include "helper.h" +#include "mem_util.cuh" + +#define NUM_WARPS_PER_BLOCK 4 +#define NUM_THREADS_PER_BLOCK 128 +#define kWarpSize 32 + +#define HOSTDEVICE __host__ __device__ + +/*-------------------------------------traits-----------------------------------------*/ +template +struct type_traits { + using paddle_type = T; + using phi_type = T; + using nv_type = T; + using nv2_type = T; +}; + +// template <> +// struct type_traits { +// using paddle_type = paddle::DataType::FLOAT16; +// using phi_type = phi::dtype::float16; +// using nv_type = half; +// using nv2_type = half2; +// }; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::float16; + using nv_type = half; + using nv2_type = half2; +}; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::float16; + using nv_type = half; + using nv2_type = half2; +}; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::float16; + using nv_type = half; + using nv2_type = half2; +}; + +// template <> +// struct type_traits { +// using paddle_type = paddle::DataType::FLOAT16; +// using phi_type = phi::dtype::bfloat16; +// using nv_type = __nv_bfloat16; +// using nv2_type = __nv_bfloat162; +// }; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::bfloat16; + using nv_type = __nv_bfloat16; + using nv2_type = __nv_bfloat162; +}; + +template <> +struct type_traits<__nv_bfloat16> { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::bfloat16; + using nv_type = __nv_bfloat16; + using nv2_type = __nv_bfloat162; +}; + +template <> +struct type_traits<__nv_bfloat162> { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::bfloat16; + using nv_type = __nv_bfloat16; + using nv2_type = __nv_bfloat162; +}; + +// template <> +// struct type_traits { +// using paddle_type = paddle::DataType::FLOAT8_E4M3FN; +// using phi_type = phi::dtype::float8_e4m3fn; +// using nv_type = __nv_fp8_e4m3; +// using nv2_type = __nv_fp8x2_e4m3; +// }; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT8_E4M3FN; + using phi_type = phi::dtype::float8_e4m3fn; + using nv_type = __nv_fp8_e4m3; + using nv2_type = __nv_fp8x2_e4m3; +}; + +template <> +struct type_traits<__nv_fp8_e4m3> { + // using paddle_type = paddle::DataType::FLOAT8_E4M3FN; + using phi_type = phi::dtype::float8_e4m3fn; + using nv_type = __nv_fp8_e4m3; + using nv2_type = __nv_fp8x2_e4m3; +}; + +template <> 
+struct type_traits<__nv_fp8x2_e4m3> { + // using paddle_type = paddle::DataType::FLOAT8_E4M3FN; + using phi_type = phi::dtype::float8_e4m3fn; + using nv_type = __nv_fp8_e4m3; + using nv2_type = __nv_fp8x2_e4m3; +}; +/*---------------------------------1. type + * traits--------------------------------------*/ + +/*---------------------------------2. fast + * convert--------------------------------------*/ +inline __device__ static void convert_fp8(half* result, + const uint32_t& source) { + printf("Do not support fp8 to half although it's very easy.\n"); +} + +inline __device__ static void convert_fp8(__nv_bfloat16* result, + const uint32_t& source) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + uint32_t dest0; + uint32_t dest1; + asm volatile( + "{\n" + ".reg .b16 lo, hi;\n" + "mov.b32 {lo, hi}, %2;\n" + "cvt.rn.f16x2.e4m3x2 %0, lo;\n" + "cvt.rn.f16x2.e4m3x2 %1, hi;\n" + "}\n" + : "=r"(dest0), "=r"(dest1) + : "r"(source)); + + ((nv_bfloat162*)(result))[0] = + __float22bfloat162_rn(__half22float2(((half2*)(&dest0))[0])); + ((nv_bfloat162*)(result))[1] = + __float22bfloat162_rn(__half22float2(((half2*)(&dest1))[0])); +#else + printf("Do not support fp8 in arch < 890\n"); + asm("trap;"); +#endif +} + +inline __device__ static void convert_int8( + half* result, const uint32_t& source) { // 4 int8 each time + uint32_t* fp16_result_ptr = reinterpret_cast(result); + uint32_t const i8s = reinterpret_cast(source); + static constexpr uint32_t mask_for_elt_01 = 0x5150; + static constexpr uint32_t mask_for_elt_23 = 0x5352; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(fp16_result_ptr[0]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(fp16_result_ptr[1]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(fp16_result_ptr[0]) + : "r"(fp16_result_ptr[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(fp16_result_ptr[1]) + : "r"(fp16_result_ptr[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); +} + +inline __device__ static void convert_int8( + __nv_bfloat16* result, const uint32_t& source) { // 4 int8 each time + uint32_t* bf16_result_ptr = reinterpret_cast(result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + +#pragma unroll + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; // (8388608.f + 128.f); + } + +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], + fp32_intermediates_casted[2 * ii + 1], + 0x7632); + } +} +/*---------------------------------2. fast + * convert--------------------------------------*/ + +/*---------------------------------3. 
vector + * cast--------------------------------------*/ +template +__forceinline__ HOSTDEVICE void vec_cast(dst_t* dst, const src_t* src) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = src[i]; + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast(float* dst, + const half* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast(half* dst, + const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast( + float* dst, const nv_bfloat16* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]); + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast(nv_bfloat16* dst, + const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); + } +} +/*---------------------------------3. vector + * cast--------------------------------------*/ + +/*-------------------------------------4. + * func-----------------------------------------*/ +__forceinline__ HOSTDEVICE int div_up(int a, int b) { return (a + b - 1) / b; } + +template +__inline__ __device__ T Rsqrt(T x); + +template <> +__inline__ __device__ float Rsqrt(float x) { + return rsqrt(x); +} + +template <> +__inline__ __device__ double Rsqrt(double x) { + return rsqrt(x); +} + +__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, + uint32_t y) { + return (x > y) ? x - y : 0U; +} + +template +inline HOSTDEVICE T roundWithTiesToEven(T x) { + T xLower = floor(x); + T xUpper = ceil(x); + // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to + // even. + T dLower = x - xLower; + T dUpper = xUpper - x; + return static_cast( + (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper) + ? xLower + : xUpper); +} + +template +HOSTDEVICE __forceinline__ uint8_t QuantToC8(const T scale, + const T value, + const float max_bound, + const float min_bound) { + uint8_t eight_bits; + float quant_value; + if constexpr (is_need_kv_quant) { + quant_value = static_cast(scale * value); + } else { + quant_value = static_cast(value); + } + if constexpr (RoundType == 0) { + quant_value = roundWithTiesToEven(quant_value); + } else { + quant_value = round(quant_value); + } + + if constexpr (IsFP8) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + quant_value = quant_value > 448.0f ? 448.0f : quant_value; + quant_value = quant_value < -448.0f ? -448.0f : quant_value; + auto tmp = static_cast<__nv_fp8_e4m3>(quant_value); + eight_bits = *(reinterpret_cast(&tmp)); +#else + printf("Do not support fp8 in arch < 890\n"); + asm("trap;"); +#endif + } else { + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? 
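+    // int8 path: clamp to [-127, 127], then bias by 128 so the value is stored as an unsigned byte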
-127.0f : quant_value; + eight_bits = static_cast(quant_value + 128.0f); + } + return eight_bits; +} + +template +inline __device__ static void convert_c8(T* result, const uint32_t& source) { + if constexpr (IsFP8) { + convert_fp8(result, source); + } else { + convert_int8(result, source); + } +} + +template +inline __device__ void WelfordCombine1(T b_m2, T* m2) { + *m2 += b_m2; +} + +template +__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) { + *m2 = thread_m2; + for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) { + T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask); + WelfordCombine1(b_m2, m2); + } +} + +template +__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) { + WelfordWarpReduce(thread_m2, m2); +} + +#define CHECK_CUDA_CALL(func, ...) \ + { \ + cudaError_t e = (func); \ + if (e != cudaSuccess) { \ + std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \ + << ") " << __FILE__ << ": line " << __LINE__ \ + << " at function " << STR(func) << std::endl; \ + return e; \ + } \ + } + +__device__ __forceinline__ float2 fast_float2_mul(const float2& a, + const float2& b) { + float2 res; + // 使用向量化PTX指令同时处理x/y分量 + asm volatile( + "{\n" + " fma.rn.f32 %0, %2, %4, 0.0;\n" // res.x = a.x * b.x + " fma.rn.f32 %1, %3, %5, 0.0;\n" // res.y = a.y * b.y + "}" + : "=f"(res.x), "=f"(res.y) // 输出操作数 + : "f"(a.x), "f"(a.y), "f"(b.x), "f"(b.y) // 输入操作数 + ); + return res; +} + +__device__ __forceinline__ float2 fast_float2_fma(float2& a, + const float2& b, + const float2& c) { + float2 res; + // 使用向量化PTX指令同时处理x/y分量 + asm volatile( + "{\n" + " fma.rn.f32 %0, %2, %4, %6;\n" // res.x = a.x * b.x + " fma.rn.f32 %1, %3, %5, %7;\n" // res.y = a.y * b.y + "}" + : "=f"(res.x), "=f"(res.y) // 输出操作数 + : "f"(a.x), + "f"(a.y), + "f"(b.x), + "f"(b.y), + "f"(c.x), + "f"(c.y) // 输入操作数 + ); + return res; +} + +// __device__ __forceinline__ float2 fast_bfloat162_fma(__nv_bfloat162& a_bf162, +// const __nv_bfloat162& b_bf162, const __nv_bfloat162& c_bf162) { +// // 使用向量化PTX指令同时处理x/y分量 +// asm volatile ( +// "{\n" +// " fma.rn.b16 %0, %2, %4, %0;\n" // res.x = a.x * b.x +// " fma.rn.b16 %1, %3, %5, %1;\n" // res.y = a.y * b.y +// "}" +// : "=r"(a_bf162.x), "=r"(a_bf162.y) // 输出操作数 +// : "r"(b_bf162.x), "r"(b_bf162.y), +// "r"(c_bf162.x), "r"(c_bf162.y) // 输入操作数 +// ); +// float2 res = __bfloat1622float2_rn(a_bf162); +// return res; +// } + +__device__ __forceinline__ float2 fast_float2_sub_expf(const float2& a, + const float2& b) { + float2 res; + // 使用向量化减法指令(PTX sub.rn.f32) + asm volatile( + "{\n" + " sub.f32 %0, %2, %4;\n" // res.x = a.x - b.x + " sub.f32 %1, %3, %5;\n" // res.y = a.y - b.y + "}" + : "=f"(res.x), "=f"(res.y) // 输出操作数 + : "f"(a.x), "f"(a.y), "f"(b.x), "f"(b.y) // 输入操作数 + ); + res.x = expf(res.x); + res.y = expf(res.y); + return res; +} + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + out_vec[i] = static_cast(ori_out_vec[i]); + printf("Fatal! 
Unimplemented StoreFunc for cascade append attention\n"); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + float quant_value = + 127.0f * + static_cast((ori_out_vec[i] + shift_bias_vec[i]) * + smooth_weight_vec[i]) * + in_scale; + quant_value = rintf(quant_value); + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + out_vec[i] = static_cast(quant_value); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector<__nv_fp8_e4m3, VEC_SIZE>& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + float quant_value = + quant_max_bound * static_cast(ori_out_vec[i]) * in_scale; + quant_value = quant_value > quant_max_bound ? quant_max_bound : quant_value; + quant_value = quant_value < quant_min_bound ? quant_min_bound : quant_value; + out_vec[i] = static_cast<__nv_fp8_e4m3>(quant_value); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + out_vec[i] = ori_out_vec[i]; + } +}; +/*-------------------------------------4. + * func-----------------------------------------*/ + +/*-----------------------------------5. + * dispatch---------------------------------------*/ +#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + switch (head_dim) { \ + case 128: { \ + constexpr size_t HEAD_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + PD_THROW("not support the head_dim"); \ + } \ + } + +#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ + if (group_size == 1) { \ + constexpr size_t GROUP_SIZE = 1; \ + __VA_ARGS__ \ + } else if (group_size == 2) { \ + constexpr size_t GROUP_SIZE = 2; \ + __VA_ARGS__ \ + } else if (group_size == 3) { \ + constexpr size_t GROUP_SIZE = 3; \ + __VA_ARGS__ \ + } else if (group_size == 4) { \ + constexpr size_t GROUP_SIZE = 4; \ + __VA_ARGS__ \ + } else if (group_size == 5) { \ + constexpr size_t GROUP_SIZE = 5; \ + __VA_ARGS__ \ + } else if (group_size == 6) { \ + constexpr size_t GROUP_SIZE = 6; \ + __VA_ARGS__ \ + } else if (group_size == 7) { \ + constexpr size_t GROUP_SIZE = 7; \ + __VA_ARGS__ \ + } else if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else if (group_size == 12) { \ + constexpr size_t GROUP_SIZE = 12; \ + __VA_ARGS__ \ + } else if (group_size == 14) { \ + constexpr size_t GROUP_SIZE = 14; \ + __VA_ARGS__ \ + } else if (group_size == 16) { \ + constexpr size_t GROUP_SIZE = 16; \ + __VA_ARGS__ \ + } else { \ + PD_THROW("not support the group_size", group_size); \ + } + +#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) 
\ + if (group_size == 1) { \ + constexpr size_t GROUP_SIZE = 1; \ + __VA_ARGS__ \ + } else if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else if (group_size == 12) { \ + constexpr size_t GROUP_SIZE = 12; \ + __VA_ARGS__ \ + } else if (group_size == 14) { \ + constexpr size_t GROUP_SIZE = 14; \ + __VA_ARGS__ \ + } else if (group_size == 16) { \ + constexpr size_t GROUP_SIZE = 16; \ + __VA_ARGS__ \ + } else { \ + PD_THROW("not support the group_size", group_size); \ + } + +#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \ + if (block_shape_q <= 16) { \ + constexpr size_t BLOCK_SHAPE_Q = 16; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } else if (block_shape_q <= 32) { \ + constexpr size_t BLOCK_SHAPE_Q = 32; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } + +#define DISPATCH_Q_TILE_SIZE( \ + group_size, max_tokens_per_batch, Q_TILE_SIZE, ...) \ + if (group_size * max_tokens_per_batch <= 16) { \ + constexpr size_t Q_TILE_SIZE = 16; \ + __VA_ARGS__ \ + } else { \ + constexpr size_t Q_TILE_SIZE = 32; \ + __VA_ARGS__ \ + } + +#define DISPATCH_CAUSAL(causal, CAUSAL, ...) \ + if (causal) { \ + constexpr bool CAUSAL = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool CAUSAL = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_BLOCKSHAPE_Q_SYSTEM( \ + block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \ + if (block_shape_q <= 16) { \ + constexpr size_t BLOCK_SHAPE_Q = 16; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } else if (block_shape_q <= 32) { \ + constexpr size_t BLOCK_SHAPE_Q = 32; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } + +#define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...) \ + if (block_size == 64) { \ + constexpr size_t BLOCK_SIZE = 64; \ + __VA_ARGS__ \ + } + +#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \ + if (is_dynamic_cfp8) { \ + constexpr bool IsDynamicC8 = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool IsDynamicC8 = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_IS_FP8(is_fp8, IS_FP8, ...) 
\ + if (is_fp8) { \ + constexpr bool IS_FP8 = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool IS_FP8 = false; \ + __VA_ARGS__ \ + } + +struct AppendAttnMetaData { + int batch_size; + int block_size; + int q_num_heads; + int kv_num_heads; + int token_num; + int head_dims; + int head_dims_v; + int max_blocks_per_seq; + const int* mask_offset = nullptr; +}; + +template +struct AttentionParams { + T* __restrict__ qkv; + CacheT* __restrict__ cache_k; + CacheT* __restrict__ cache_v; + T* __restrict__ cache_k_scale; + T* __restrict__ cache_v_scale; + int* __restrict__ seq_lens_q; + int* __restrict__ seq_lens_kv; + int* __restrict__ block_indices; + int* __restrict__ num_blocks_ptr; + int* __restrict__ chunk_size_ptr; + int* __restrict__ cu_seqlens_q; + int* __restrict__ block_table; + int* __restrict__ mask_offset; + bool* __restrict__ attn_mask; + T* __restrict__ tmp_o; + float* __restrict__ tmp_m; + float* __restrict__ tmp_d; + int max_model_len; + int max_kv_len; + int max_blocks_per_seq; + float softmax_scale; + float quant_max_bound; + float quant_min_bound; + int num_blocks_x; + int attn_mask_len; + bool sliding_window; + int q_num_heads; + int kv_num_heads; + int max_num_chunks; + int max_tile_q; + int batch_size; + int token_num; + int head_dims; + int max_tokens_per_batch; +}; diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 109ed3fd8d8..bdb977ebc49 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -175,6 +175,83 @@ std::vector AppendAttentionWithOutput( const bool speculate_decoder, const int sliding_window); +std::vector DecoderWriteCacheWithRoPE( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& set_max_lengths, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_bias, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder); + +std::vector DecodeAppendAttention( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const paddle::Tensor& set_max_lengths, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& 
cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& mask_offset, + const paddle::optional& sinks, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window); + +void ConfigForAttention(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + paddle::Tensor& block_indices, // Inplace + paddle::Tensor& num_blocks, // Inplace + paddle::Tensor& chunk_size, // Inplace + paddle::Tensor& max_len_tensor_cpu, // Inplace, CPU + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch); + std::vector GQARopeWriteCacheKernel( const paddle::Tensor& qkv, const paddle::Tensor& key_cache, @@ -1175,6 +1252,22 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("flash_mask_attention", &FlashAttentionMask, "flash_mask_attention"); + /** + * decoder_write_cache_with_rope.cu + * decoder_write_cache_with_rope + */ + m.def("decoder_write_cache_with_rope", + &DecoderWriteCacheWithRoPE, + "decoder write cache with RoPE function"); + + /** + * decode_append_attention.cu + * decode_append_attention + */ + m.def("decode_append_attention", + &DecodeAppendAttention, + "decoder append attention function"); + /** * gqa_rope_write_cache.cu * gqa_rope_write_cache @@ -1182,6 +1275,15 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("gqa_rope_write_cache", &GQARopeWriteCacheKernel, "gqa rope write cache function"); + + /** + * config_for_attention.cu + * config_for_attention + */ + m.def("config_for_attention", + &ConfigForAttention, + "config for attention function"); + /** * pre_cache_len_concat.cu * pre_cache_len_concat diff --git a/custom_ops/gpu_ops/decode_append_attention.cu b/custom_ops/gpu_ops/decode_append_attention.cu new file mode 100644 index 00000000000..3a9eaca2e8a --- /dev/null +++ b/custom_ops/gpu_ops/decode_append_attention.cu @@ -0,0 +1,344 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
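+// Host-side entry point for decode-phase append attention with an 8-bit KV cache: fills AttentionParams and dispatches the C8 kernels by dtype, GQA group size, head dim, block size, causality and cache quantization mode.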
+ +#include "append_attention/decode_append_attention_c8_impl.cuh" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +template +class type2value; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::BFLOAT16; +}; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::FLOAT16; +}; + +std::vector DecodeAppendAttention( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const paddle::Tensor& set_max_lengths, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& mask_offset, + const paddle::optional& sinks, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window) { + AppendAttnMetaData meta_data; + + const auto& qkv_dims = qkv.dims(); + const auto& key_cache_dims = key_cache.dims(); + meta_data.token_num = qkv_dims[0]; + meta_data.kv_num_heads = key_cache_dims[1]; + meta_data.head_dims = key_cache_dims[3]; + // TODO: trick method support c4, add attr head_dims in the future + if (cache_quant_type == "cache_int4_zp") { + meta_data.head_dims *= 2; + } + const int total_num_head = + qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims; + meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads; + const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; + + meta_data.max_blocks_per_seq = block_tables.dims()[1]; + meta_data.block_size = key_cache.dims()[2]; + meta_data.batch_size = seq_lens_this_time.dims()[0]; + + // fmha_out generation, rewrite from AppendAttentionKernel + paddle::Tensor fmha_out = paddle::zeros( + {meta_data.token_num, meta_data.q_num_heads * meta_data.head_dims}, + qkv.dtype(), + qkv.place()); + + if (mask_offset) { + meta_data.mask_offset = mask_offset.get().data(); + } + + const int max_just_dec_len_this_time = set_max_lengths.data()[4]; + const int max_kv_len_this_time = set_max_lengths.data()[5]; + + auto stream = qkv.stream(); + bool is_fp8 = + cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8"; + bool is_dynamic_cfp8 = cache_quant_type == "block_wise_fp8"; + + if (max_just_dec_len_this_time > 0) { + DISPATCH_CAUSAL( + causal, + CAUSAL, + {DISPATCH_GQA_GROUP_SIZE( + group_size, + GROUP_SIZE, + {DISPATCH_HEAD_DIM( + meta_data.head_dims, + HEAD_DIM, + {DISPATCH_BLOCK_SIZE( + meta_data.block_size, + BLOCK_SIZE, + {DISPATCH_Q_TILE_SIZE( + group_size, + max_tokens_per_batch, + Q_TILE_SIZE, + {DISPATCH_DyCfp8( + is_dynamic_cfp8, + IsDynamicC8, + {DISPATCH_IS_FP8(is_fp8, IsFP8, { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + 
DecodeAppendC8Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + cache_quant_type == "block_wise_fp8" + ? cache_k_quant_scales.get() + : cache_k_dequant_scales.get(), + cache_quant_type == "block_wise_fp8" + ? cache_v_quant_scales.get() + : cache_v_dequant_scales.get(), + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + quant_max_bound, + quant_min_bound, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + case paddle::DataType::FLOAT16: { + DecodeAppendC8Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + cache_quant_type == "block_wise_fp8" + ? cache_k_quant_scales.get() + : cache_k_dequant_scales.get(), + cache_quant_type == "block_wise_fp8" + ? cache_v_quant_scales.get() + : cache_v_dequant_scales.get(), + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + quant_max_bound, + quant_min_bound, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are " + "supported. "); + } + })})})})})})}) + } + return {fmha_out}; +} + +std::vector> DecodeAppendAttentionInferShape( + const std::vector& qkv_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape, + const std::vector& tmp_workspace_shape, + const std::vector& tmp_m_shape, + const std::vector& tmp_d_shape, + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& batch_id_per_token_shape, + const std::vector& cu_seqlens_q_shape, + const std::vector& block_tables_shape, + const std::vector& block_indices_shape, + const std::vector& num_blocks_shape, + const std::vector& chunk_size_shape, + const std::vector& set_max_lengths_shape, + const paddle::optional>& attn_mask_shape, + const paddle::optional>& cache_k_quant_scales_shape, + const paddle::optional>& cache_v_quant_scales_shape, + const paddle::optional>& cache_k_dequant_scales_shape, + const paddle::optional>& cache_v_dequant_scales_shape, + const paddle::optional>& cache_k_zp_shape, + const paddle::optional>& cache_v_zp_shape, + const paddle::optional>& mask_offset_shape, + const paddle::optional>& sinks_shape, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window) { + const int token_num = qkv_shape[0]; + const int kv_num_heads = key_cache_shape[1]; + int head_dim = key_cache_shape[3]; + if (cache_quant_type == "cache_int4_zp") { + head_dim *= 2; + } + const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim; + const int num_heads = total_num_head - 2 * kv_num_heads; + return {{token_num, num_heads * head_dim}}; +} + +std::vector DecodeAppendAttentionInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype, + const paddle::DataType& tmp_workspace_dtype, + const paddle::DataType& tmp_m_dtype, + const paddle::DataType& tmp_d_dtype, + const 
paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& batch_id_per_token_dtype, + const paddle::DataType& cu_seqlens_q_dtype, + const paddle::DataType& block_tables_dtype, + const paddle::DataType& block_indices_dtype, + const paddle::DataType& num_blocks_dtype, + const paddle::DataType& chunk_size_dtype, + const paddle::DataType& set_max_lengths_dtype, + const paddle::optional& attn_mask_dtype, + const paddle::optional& cache_k_quant_scales_dtype, + const paddle::optional& cache_v_quant_scales_dtype, + const paddle::optional& cache_k_dequant_scales_dtype, + const paddle::optional& cache_v_dequant_scales_dtype, + const paddle::optional& cache_k_zp_dtype, + const paddle::optional& cache_v_zp_dtype, + const paddle::optional& mask_offset_dtype, + const paddle::optional& sinks_dtype, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window) { + return {qkv_dtype}; +} + +PD_BUILD_STATIC_OP(decode_append_attention) + .Inputs({"qkv", + "key_cache", + "value_cache", + "tmp_workspace", + "tmp_m", + "tmp_d", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "batch_id_per_token", + "cu_seqlens_q", + "block_tables", + "block_indices", + "num_blocks", + "chunk_size", + "set_max_lengths", + paddle::Optional("attn_mask"), + paddle::Optional("cache_k_quant_scales"), + paddle::Optional("cache_v_quant_scales"), + paddle::Optional("cache_k_dequant_scales"), + paddle::Optional("cache_v_dequant_scales"), + paddle::Optional("cache_k_zp"), + paddle::Optional("cache_v_zp"), + paddle::Optional("mask_offset"), + paddle::Optional("sinks")}) + .Outputs({"fmha_out"}) + .Attrs({ + "cache_quant_type: std::string", + "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", + "max_tokens_per_batch: int", + "causal: bool", + "sliding_window: int", + }) + .SetKernelFn(PD_KERNEL(DecodeAppendAttention)) + .SetInferShapeFn(PD_INFER_SHAPE(DecodeAppendAttentionInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(DecodeAppendAttentionInferDtype)); diff --git a/custom_ops/gpu_ops/decoder_write_cache_with_rope.cu b/custom_ops/gpu_ops/decoder_write_cache_with_rope.cu new file mode 100644 index 00000000000..7878e9926c5 --- /dev/null +++ b/custom_ops/gpu_ops/decoder_write_cache_with_rope.cu @@ -0,0 +1,326 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
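+// decoder_write_cache_with_rope: applies RoPE (plus optional qkv bias and Q/K RMS
+// norm) to the decode-step qkv in place and appends the new tokens' K/V into the
+// paged cache via block_tables. Depending on speculate_decoder it routes to
+// SpeculateWriteCacheWithRoPEKernel or DecoderWriteCacheWithRoPEKernel, for
+// bfloat16 or float16 activations, and only runs when the step actually has
+// decode tokens (max_just_dec_len_this_time > 0).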
+ +#include "append_attn/decoder_write_cache_with_rope_kernel.h" +#include "append_attn/speculate_write_cache_with_rope_kernel.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +template +class type2value; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::BFLOAT16; +}; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::FLOAT16; +}; + +std::vector DecoderWriteCacheWithRoPE( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& set_max_lengths, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_bias, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder) { + auto stream = qkv.stream(); + + AppendAttnMetaData meta_data; + + const auto& qkv_dims = qkv.dims(); + const auto& key_cache_dims = key_cache.dims(); + meta_data.token_nums = qkv_dims[0]; + meta_data.kv_num_heads = key_cache_dims[1]; + meta_data.head_dims = key_cache_dims[3]; + // TODO: trick method support c4, add attr head_dims in the future + if (cache_quant_type_str == "cache_int4_zp") { + meta_data.head_dims *= 2; + } + const int total_num_head = + qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims; + meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads; + + meta_data.max_blocks_per_seq = block_tables.dims()[1]; + meta_data.block_size = key_cache.dims()[2]; + meta_data.batch_size = seq_lens_this_time.dims()[0]; + + const int max_just_dec_len_this_time = set_max_lengths.data()[4]; + + if (max_just_dec_len_this_time > 0) { + if (speculate_decoder) { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + case paddle::DataType::FLOAT16: { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + 
stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are supported. "); + } + } else { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + case paddle::DataType::FLOAT16: { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are supported. "); + } + } + } + return {qkv}; +} + +std::vector> DecoderWriteCacheWithRoPEInferShape( + const std::vector& qkv_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape, + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& batch_id_per_token_shape, + const std::vector& cu_seqlens_q_shape, + const std::vector& block_tables_shape, + const std::vector& set_max_lengths_shape, + const paddle::optional>& rotary_embs_shape, + const paddle::optional>& qkv_bias_shape, + const paddle::optional>& cache_k_quant_scales_shape, + const paddle::optional>& cache_v_quant_scales_shape, + const paddle::optional>& cache_k_dequant_scales_shape, + const paddle::optional>& cache_v_dequant_scales_shape, + const paddle::optional>& cache_k_zp_shape, + const paddle::optional>& cache_v_zp_shape, + const paddle::optional>& kv_signal_data_shape, + const paddle::optional>& q_norm_weight_shape, + const paddle::optional>& k_norm_weight_shape, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder) { + return {qkv_shape}; +} + +std::vector DecoderWriteCacheWithRoPEInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype, + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& batch_id_per_token_dtype, + const paddle::DataType& cu_seqlens_q_dtype, + const paddle::DataType& block_tables_dtype, + const paddle::DataType& set_max_lengths_dtype, + const paddle::optional& rotary_embs_dtype, + const paddle::optional& qkv_bias_dtype, + const paddle::optional& cache_k_quant_scales_dtype, + const paddle::optional& cache_v_quant_scales_dtype, + const paddle::optional& 
cache_k_dequant_scales_dtype, + const paddle::optional& cache_v_dequant_scales_dtype, + const paddle::optional& cache_k_zp_dtype, + const paddle::optional& cache_v_zp_dtype, + const paddle::optional& kv_signal_data_dtype, + const paddle::optional& q_norm_weight_dtype, + const paddle::optional& k_norm_weight_dtype, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder) { + return {qkv_dtype}; +} + +PD_BUILD_STATIC_OP(decoder_write_cache_with_rope) + .Inputs({"qkv", + "key_cache", + "value_cache", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "batch_id_per_token", + "cu_seqlens_q", + "block_tables", + "set_max_lengths", + paddle::Optional("rotary_embs"), + paddle::Optional("qkv_bias"), + paddle::Optional("cache_k_quant_scales"), + paddle::Optional("cache_v_quant_scales"), + paddle::Optional("cache_k_dequant_scales"), + paddle::Optional("cache_v_dequant_scales"), + paddle::Optional("cache_k_zp"), + paddle::Optional("cache_v_zp"), + paddle::Optional("kv_signal_data"), + paddle::Optional("q_norm_weight"), + paddle::Optional("k_norm_weight")}) + .Outputs({"qkv_out"}) + .SetInplaceMap({{"qkv", "qkv_out"}}) + .Attrs({ + "rms_norm_eps: float", + "cache_quant_type: std::string", + "use_neox_rotary_style: bool", + "rope_3d: bool", + "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", + "speculate_decoder: bool", + }) + .SetKernelFn(PD_KERNEL(DecoderWriteCacheWithRoPE)) + .SetInferShapeFn(PD_INFER_SHAPE(DecoderWriteCacheWithRoPEInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(DecoderWriteCacheWithRoPEInferDtype)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 40900b18771..26bd49178c8 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -489,6 +489,13 @@ def find_end_files(directory, end_str): sources += find_end_files(fp8_auto_gen_directory, ".cu") if cc >= 90 and nvcc_version >= 12.0: + # decode attention + os.system( + "python utils/auto_gen_template_attention.py --config gpu_ops/append_attention/template_config.json --output gpu_ops/append_attention/template_instantiation/autogen" + ) + sources += ["gpu_ops/decode_append_attention.cu"] + sources += ["gpu_ops/decoder_write_cache_with_rope.cu"] + sources += find_end_files("gpu_ops/append_attention", ".cu") # Hopper optimized mla sources += find_end_files("gpu_ops/mla_attn", ".cu") sources += ["gpu_ops/flash_mask_attn/flash_mask_attn.cu"] @@ -503,6 +510,7 @@ def find_end_files(directory, end_str): os.system("python gpu_ops/machete/generate.py") sources += find_end_files("gpu_ops/machete", ".cu") cc_compile_args += ["-DENABLE_MACHETE"] + nvcc_compile_args += ["--use_fast_math"] setup( name="fastdeploy_ops", diff --git a/custom_ops/utils/auto_gen_template_attention.py b/custom_ops/utils/auto_gen_template_attention.py new file mode 100644 index 00000000000..febec11b336 --- /dev/null +++ b/custom_ops/utils/auto_gen_template_attention.py @@ -0,0 +1,227 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Universal template instantiation generator - fully based on configuration file template instantiation generation.""" + +import argparse +import json +import shutil +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +@dataclass +class TemplateConfig: + """Template configuration class.""" + + name: str # Function name + function_name: str # Actual function name + impl_file: str # Implementation file path + template_params: List[str] # Template parameter list (in order) + dispatch_params: Dict[str, List[Any]] # Dispatch parameters + data_types: Optional[List[Tuple[str, str, str]]] = None # Data type combinations (input_type, output_type, suffix) + max_instances_per_file: int = 60 # Maximum instances per file + file_prefix: str = "" # File prefix + function_signature: str = "" # Function signature template + + +class UniversalTemplateInstantiator: + """Universal template instantiator - fully based on configuration file.""" + + def __init__(self, config_file: str): + """Initialize the instantiator.""" + self.config_file = config_file + self.configs = self._load_configs() + + def _load_configs(self) -> Dict[str, TemplateConfig]: + """Load configuration file.""" + with open(self.config_file, "r", encoding="utf-8") as f: + config_data = json.load(f) + + configs = {} + for name, config_dict in config_data.items(): + config = TemplateConfig(**config_dict) + self._validate_config(config) + configs[name] = config + return configs + + def _validate_config(self, config: TemplateConfig): + """Validate configuration completeness.""" + for param_name in config.template_params: + if param_name not in config.dispatch_params: + raise ValueError(f"Template parameter '{param_name}' in '{config.name}' not found in dispatch_params") + + def _build_template_args(self, config: TemplateConfig, params: Dict[str, Any]) -> str: + """Build template arguments.""" + template_args_parts = [] + + for param_name in config.template_params: + if param_name in params: + template_args_parts.append(str(params[param_name])) + + else: + raise ValueError(f"Template parameter '{param_name}' not found in dispatch_params") + + return f"<{', '.join(template_args_parts)}>" + + def _build_params_template_args(self, params: Dict[str, Any]) -> str: + """Build template arguments.""" + params_template_args = [] + # breakpoint() + if "T" in params: + params_template_args.append(str(params["T"])) + else: + raise ValueError("Template parameter 'T' not found in dispatch_params") + + if "CacheT" in params: + params_template_args.append(str(params["CacheT"])) + else: + raise ValueError("Template parameter 'CacheT' not found in dispatch_params") + + return f"<{', '.join(params_template_args)}>" + + def _generate_function_signature( + self, config: TemplateConfig, template_args: str, params_template_args: str + ) -> str: + """Generate function signature.""" + if config.function_signature: + signature = config.function_signature.format( + function_name=config.function_name, + template_args=template_args, + params_template_args=params_template_args, + ) + + return signature 
+ else: + raise ValueError(f"Function signature not found for {config.name}") + + def _generate_file_header(self, config: TemplateConfig) -> str: + """Generate file header.""" + return f"""// Generated by autogen_template_instantiation.py - Do not edit. + +#pragma once + +#include "../../{config.impl_file}" +""" + + def _generate_template_instantiation(self, config: TemplateConfig, params: Dict[str, Any]) -> str: + """Generate template instantiation.""" + template_args = self._build_template_args(config, params) + params_template_args = self._build_params_template_args(params) + return self._generate_function_signature(config, template_args, params_template_args) + + def _clean_output_directory(self, output_dir: str): + """Clean output directory before generating new files.""" + output_path = Path(output_dir) + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + def generate_combinations_for_type(self, config: TemplateConfig) -> List[Dict[str, Any]]: + """Generate parameter combinations for specific type.""" + combinations = [] + + def _generate_recursive( + params_dict: Dict[str, List[Any]], current_params: Dict[str, Any], param_names: List[str] + ): + if not param_names: + combinations.append(current_params.copy()) + return + + param_name = param_names[0] + for value in params_dict[param_name]: + current_params[param_name] = value + _generate_recursive(params_dict, current_params, param_names[1:]) + + _generate_recursive(config.dispatch_params, {}, list(config.dispatch_params.keys())) + + return combinations + + def split_combinations(self, combinations: List[Dict[str, Any]], max_per_file: int) -> List[List[Dict[str, Any]]]: + """Split combinations into multiple files.""" + chunks = [] + for i in range(0, len(combinations), max_per_file): + chunk = combinations[i : i + max_per_file] + chunks.append(chunk) + return chunks + + def generate_file_content( + self, + config: TemplateConfig, + file_index: int, + combinations: List[Dict[str, Any]], + ) -> str: + """Generate file content.""" + content = self._generate_file_header(config) + + for params in combinations: + content += self._generate_template_instantiation(config, params) + + return content + + def generate_for_function_type(self, function_name: str, output_dir: str): + """Generate template instantiation files for specific function type.""" + if function_name not in self.configs: + raise ValueError(f"Function type '{function_name}' not found in config") + + config = self.configs[function_name] + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + combinations = self.generate_combinations_for_type(config) + if combinations: + chunks = self.split_combinations(combinations, config.max_instances_per_file) + for i, chunk in enumerate(chunks): + filename = f"{config.file_prefix}_part_{i:02d}.cu" + filepath = output_path / filename + content = self.generate_file_content(config, i, chunk) + with open(filepath, "w", encoding="utf-8") as f: + f.write(content) + + def generate_all(self, output_dir: str): + """Generate all configured function types.""" + self._clean_output_directory(output_dir) + for function_name in self.configs.keys(): + print(f"Generating template instantiations for {function_name}...") + self.generate_for_function_type(function_name, output_dir) + print(f"Completed generating {function_name} template instantiations.") + + +def main(): + """Main function.""" + parser = argparse.ArgumentParser(description="Universal template instantiation 
generator") + parser.add_argument( + "--config", + "-c", + type=str, + help="Configuration file path (JSON format)", + ) + parser.add_argument( + "--output", + "-o", + type=str, + help="Output directory", + ) + + args = parser.parse_args() + + try: + instantiator = UniversalTemplateInstantiator(args.config) + instantiator.generate_all(args.output) + except Exception as e: + print(f"Error: {e}") + + +if __name__ == "__main__": + main() diff --git a/fastdeploy/model_executor/layers/attention/__init__.py b/fastdeploy/model_executor/layers/attention/__init__.py index 1ae0ef361de..c3972e7b2d1 100644 --- a/fastdeploy/model_executor/layers/attention/__init__.py +++ b/fastdeploy/model_executor/layers/attention/__init__.py @@ -17,6 +17,7 @@ from .attention_selecter import get_attention_backend from .base_attention_backend import AttentionBackend from .block_multihead_attn_backend import BlockAttentionBackend +from .decode_append_attention_backend import DecodeAppendAttentionBackend from .flash_attn_backend import FlashAttentionBackend from .flash_mask_attn_backend import FlashMaskAttentionBackend from .iluvatar_attn_backend import IluvatarAttnBackend @@ -30,6 +31,7 @@ "PaddleNativeAttnBackend", "get_attention_backend", "AppendAttentionBackend", + "DecodeAppendAttentionBackend", "XPUAttentionBackend", "MLAAttentionBackend", "FlashAttentionBackend", diff --git a/fastdeploy/model_executor/layers/attention/decode_append_attention_backend.py b/fastdeploy/model_executor/layers/attention/decode_append_attention_backend.py new file mode 100644 index 00000000000..522f73c69fe --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/decode_append_attention_backend.py @@ -0,0 +1,319 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional + +import paddle + +from fastdeploy.model_executor.layers.attention.ops import ( + config_for_attention, + decode_append_attention, + decoder_write_cache_with_rope, + init_kv_signal_per_query, + init_signal_layerwise, + open_shm_and_get_meta_signal, +) + +if TYPE_CHECKING: + from fastdeploy.model_executor.forward_meta import ForwardMeta + + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, + AttentionMetadata, +) +from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id + + +@dataclass +class DecodeAppendAttentionMetadata(AttentionMetadata): + """ + AppendAttentionMetadata + """ + + _dtype: paddle.dtype = paddle.bfloat16 + # pd_disaggregation + kv_signal_metadata: Optional[paddle.Tensor] = None + kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list) + + +class DecodeAppendAttentionBackend(AttentionBackend): + """ + AppendAttentionBackend backend implementation. + """ + + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: DecodeAppendAttentionMetadata + + def __init__( + self, + fd_config: FDConfig, + kv_num_heads: int, + num_heads: int, + head_dim: int, + encoder_block_shape_q: int = -1, + decoder_block_shape_q: int = -1, + ) -> None: + """ + AppendAttentionBackend __init__ + """ + super().__init__() + self.attention_metadata: DecodeAppendAttentionMetadata = None + self.block_size: int = fd_config.cache_config.block_size + self.max_seq_len: int = fd_config.model_config.max_model_len + self.rope_theta: float = ( + 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta + ) + self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr( + fd_config.model_config, "use_3d_rope", False + ) + if fd_config.speculative_config.model_type != "main": + self.rope_3d = False + self.causal: bool = getattr(fd_config.model_config, "causal", True) + self.speculative_method: str = fd_config.speculative_config.method + self.speculate_max_draft_token_num: int = ( + fd_config.speculative_config.num_speculative_tokens if self.speculative_method is not None else 0 + ) + self.max_tokens_per_batch = self.speculate_max_draft_token_num + 1 + self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" + self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) + + self.kv_num_heads: int = kv_num_heads + self.num_heads: int = num_heads + self.group_size: int = self.num_heads // self.kv_num_heads + self.head_dim: int = fd_config.model_config.head_dim + + self.num_layers: int = fd_config.model_config.num_hidden_layers + + self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode + + self.start_layer_index: int = fd_config.model_config.start_layer_index + + if fd_config.parallel_config.expert_parallel_rank is None: + fd_config.parallel_config.expert_parallel_rank = 0 + + self.rank, self.device_id = init_rank_and_device_id(fd_config) + self.use_output = not fd_config.graph_opt_config.full_cuda_graph + self.fd_config = fd_config + self.buffer: dict = {} + + def init_buffer( + self, + max_batch_size: int, + ) -> dict: + # Initialize AttentionBackend buffers + assert self.num_heads % self.kv_num_heads == 0 + assert 
self.max_seq_len % self.block_size == 0 + + min_chunk_size = 128 + max_num_chunk = (self.max_seq_len + min_chunk_size - 1) // min_chunk_size + + q_tile_size = 16 if self.max_tokens_per_batch * self.group_size <= 16 else 32 + q_tile_num = (self.max_tokens_per_batch * self.group_size + q_tile_size - 1) // q_tile_size + self.buffer["max_len_tensor_cpu"] = paddle.full([6], 0, dtype="int32").cpu() + # block_indices: Launched block's indices with 4 dimensions [batch_idx, kv_head_idx, chunk_idx, q_tile_idx] in decode append attention backend + self.buffer["block_indices"] = paddle.full( + [max_batch_size * self.kv_num_heads * max_num_chunk * q_tile_num, 4], 0, dtype="int32" + ) + # num_blocks: Number of Launched blocks in decode append attention backend, researched by config_for_attention op + self.buffer["num_blocks"] = paddle.full([1], 0, dtype="int32") + # chunk_size: Chunk size for split kv cache in decode append attention backend, researched by config_for_attention op + self.buffer["chunk_size"] = paddle.full([1], 0, dtype="int32") + # tmp_workspace: Workspace tensor for temporary store the result before merging in decode append attention backend + self.buffer["tmp_workspace"] = paddle.full( + [max_batch_size * self.max_tokens_per_batch, max_num_chunk, self.num_heads * self.head_dim], + 0, + dtype=paddle.get_default_dtype(), + ) + # tmp_m: Tmp_m tensor for temporary store the max value before merging in decode append attention backend + self.buffer["tmp_m"] = paddle.full( + [max_batch_size * self.max_tokens_per_batch, max_num_chunk, self.num_heads], 0, dtype="float32" + ) + # tmp_d: Tmp_d tensor for temporary store the exponential sum before merging in decode append attention backend + self.buffer["tmp_d"] = paddle.full( + [max_batch_size * self.max_tokens_per_batch, max_num_chunk, self.num_heads], 0, dtype="float32" + ) + + def init_attention_metadata(self, forward_meta: ForwardMeta): + """Initialize attntion metadata hence all layers in the forward pass can reuse it.""" + metadata = DecodeAppendAttentionMetadata() + metadata._dtype = paddle.get_default_dtype() + + # pd_disaggregation + metadata.kv_signal_data_list = [None] * self.num_layers + if self.pd_disaggregation_mode == "per_chunk": + if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + init_kv_signal_per_query( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_this_time, + forward_meta.seq_lens_decoder, + self.rank, + self.num_layers + self.num_layers_draft_model, + ) + elif self.pd_disaggregation_mode == "per_query": + metadata.kv_signal_metadata = open_shm_and_get_meta_signal( + self.rank, int(self.device_id), self.keep_pd_step_flag + ) + + self.attention_metadata: AttentionMetadata = metadata + + def get_attntion_meta(self) -> AttentionMetadata: + """get_attntion_meta""" + return self.attention_metadata + + def get_kv_cache_shape( + self, + max_num_blocks: int, + kv_cache_quant_type: str = None, + ): + """ + Calculate kv cache shape + """ + key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim] + if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp": + key_cache_shape[-1] = self.head_dim // 2 + value_cache_shape = key_cache_shape + return key_cache_shape, value_cache_shape + + def forward_mixed( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + """ + forward_mixed + """ + metadata = 
self.attention_metadata + sliding_window = layer.sliding_window + + if self.rope_3d: + assert len(forward_meta.rotary_embs.shape) == 6 + else: + assert len(forward_meta.rotary_embs.shape) == 5 + if layer.use_neox_rotary_style: + assert forward_meta.rotary_embs.shape[0:4] == [2, 1, self.max_seq_len, 1] + # 128 is qwen3 + # 32 is glm + assert forward_meta.rotary_embs.shape[4] in [128, 32] + + if self.pd_disaggregation_mode == "per_query": + metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( + metadata.kv_signal_metadata, + layer.layer_id + self.start_layer_index, + ) + cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none") + if cache_quant_type_str == "block_wise_fp8": + cache_k = forward_meta.caches[4 * layer.layer_id] + cache_v = forward_meta.caches[4 * layer.layer_id + 1] + cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2] + cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3] + else: + cache_k = forward_meta.caches[2 * layer.layer_id] + cache_v = forward_meta.caches[2 * layer.layer_id + 1] + cache_k_scales = getattr(layer, "cache_k_scale", None) + cache_v_scales = getattr(layer, "cache_v_scale", None) + + if layer.layer_id == 0: + config_for_attention( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + self.buffer["block_indices"], + self.buffer["num_blocks"], + self.buffer["chunk_size"], + self.buffer["max_len_tensor_cpu"], + getattr(layer, "cache_quant_type_str", "none"), + self.group_size, + self.kv_num_heads, + self.max_tokens_per_batch, + ) + qkv_out = decoder_write_cache_with_rope( + qkv, + cache_k, + cache_v, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + forward_meta.block_tables, + self.buffer["max_len_tensor_cpu"], + forward_meta.rotary_embs, + getattr(layer, "qkv_bias", None), + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + metadata.kv_signal_data_list[layer.layer_id], + getattr(layer, "q_norm_weight", None), + getattr(layer, "k_norm_weight", None), + getattr(layer, "rms_norm_eps", 1e-6), + getattr(layer, "cache_quant_type_str", "none"), + layer.use_neox_rotary_style, + self.rope_3d, + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), + self.speculative_method is not None, + ) + res = decode_append_attention( + qkv_out, + cache_k, + cache_v, + self.buffer["tmp_workspace"], + self.buffer["tmp_m"], + self.buffer["tmp_d"], + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + forward_meta.block_tables, + self.buffer["block_indices"], + self.buffer["num_blocks"], + self.buffer["chunk_size"], + self.buffer["max_len_tensor_cpu"], + forward_meta.attn_mask, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + forward_meta.attn_mask_offsets, + getattr(layer, "sinks", None), + getattr(layer, "cache_quant_type_str", "none"), + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), + self.max_tokens_per_batch, + self.causal, + sliding_window, + ) + return res 
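For intuition about the buffers the backend preallocates, the sizing follows directly from init_buffer above. A minimal, self-contained sketch with illustrative shapes (the head counts, sequence length, and batch size below are assumptions for the example, not values taken from this PR):

# Workspace sizing sketch mirroring DecodeAppendAttentionBackend.init_buffer.
# All concrete numbers are illustrative assumptions.
max_seq_len = 8192                               # model_config.max_model_len
num_heads, kv_num_heads, head_dim = 32, 4, 128
group_size = num_heads // kv_num_heads           # 8
max_tokens_per_batch = 1                         # speculate_max_draft_token_num + 1, no MTP here
max_batch_size = 64                              # scheduler_config.max_num_seqs

min_chunk_size = 128
max_num_chunk = (max_seq_len + min_chunk_size - 1) // min_chunk_size               # 64 KV chunks
q_tile_size = 16 if max_tokens_per_batch * group_size <= 16 else 32                # 16
q_tile_num = (max_tokens_per_batch * group_size + q_tile_size - 1) // q_tile_size  # 1

# One [batch_idx, kv_head_idx, chunk_idx, q_tile_idx] row per potentially launched block.
block_indices_rows = max_batch_size * kv_num_heads * max_num_chunk * q_tile_num    # 16384
# Per-chunk partial outputs, running max and exp-sum, kept until the merge step.
tmp_workspace_shape = (max_batch_size * max_tokens_per_batch, max_num_chunk, num_heads * head_dim)
tmp_m_shape = tmp_d_shape = (max_batch_size * max_tokens_per_batch, max_num_chunk, num_heads)

print(block_indices_rows, tmp_workspace_shape, tmp_m_shape)  # 16384 (64, 64, 4096) (64, 64, 32)

At runtime this backend is selected with FD_ATTENTION_BACKEND=DECODE_APPEND_ATTN, and the runner calls init_buffer(max_num_seqs) once right after constructing it, as the platform and runner hunks below show.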
diff --git a/fastdeploy/model_executor/layers/attention/ops/__init__.py b/fastdeploy/model_executor/layers/attention/ops/__init__.py index 064155d2ccc..9bfa37378f1 100644 --- a/fastdeploy/model_executor/layers/attention/ops/__init__.py +++ b/fastdeploy/model_executor/layers/attention/ops/__init__.py @@ -15,6 +15,9 @@ """ from .append_attention import append_attention, append_attention_with_output +from .config_for_attention import config_for_attention +from .decode_append_attention import decode_append_attention +from .decoder_write_cache_with_rope import decoder_write_cache_with_rope from .flash_mask_attention import flash_mask_attention from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block from .gqa_rope_write_cache import gqa_rope_write_cache @@ -33,4 +36,7 @@ "pre_cache_len_concat", "init_kv_signal_per_query", "flash_mask_attention", + "config_for_attention", + "decoder_write_cache_with_rope", + "decode_append_attention", ] diff --git a/fastdeploy/model_executor/layers/attention/ops/config_for_attention.py b/fastdeploy/model_executor/layers/attention/ops/config_for_attention.py new file mode 100644 index 00000000000..d8226aad4b1 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/config_for_attention.py @@ -0,0 +1,58 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import paddle + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + config_for_attention as config_for_attention_cuda, + ) + + +def config_for_attention( + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + block_indices: paddle.Tensor, + num_blocks: paddle.Tensor, + chunk_size: paddle.Tensor, + max_len_tensor_cpu: paddle.Tensor, + cache_quant_type: str = "none", + group_size: int = 1, + kv_num_heads: int = 1, + max_tokens_per_batch: int = 1, +): + """ + append_attention + """ + if current_platform.is_cuda(): + config_for_attention_cuda( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + block_indices, + num_blocks, + chunk_size, + max_len_tensor_cpu, + cache_quant_type, + group_size, + kv_num_heads, + max_tokens_per_batch, + ) + else: + raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/attention/ops/decode_append_attention.py b/fastdeploy/model_executor/layers/attention/ops/decode_append_attention.py new file mode 100644 index 00000000000..36999cdfb25 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/decode_append_attention.py @@ -0,0 +1,103 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Optional + +import paddle + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + decode_append_attention as decode_append_attention_cuda, + ) + + +def decode_append_attention( + qkv: paddle.Tensor, + key_cache: paddle.Tensor, + value_cache: paddle.Tensor, + tmp_workspace: paddle.Tensor, + tmp_m: paddle.Tensor, + tmp_d: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + batch_id_per_token: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + block_tables: paddle.Tensor, + block_indices: paddle.Tensor, + num_blocks: paddle.Tensor, + chunk_size: paddle.Tensor, + set_max_lengths: paddle.Tensor, + attn_mask: Optional[paddle.Tensor] = None, + k_quant_scale: Optional[paddle.Tensor] = None, + v_quant_scale: Optional[paddle.Tensor] = None, + k_dequant_scale: Optional[paddle.Tensor] = None, + v_dequant_scale: Optional[paddle.Tensor] = None, + cache_k_zp: Optional[paddle.Tensor] = None, + cache_v_zp: Optional[paddle.Tensor] = None, + mask_offset: Optional[paddle.Tensor] = None, + sinks: Optional[paddle.Tensor] = None, + cache_quant_type: str = "none", + max_input_length: int = 0, + quant_max_bound: float = 0.0, + quant_min_bound: float = 0.0, + max_tokens_per_batch: int = 1, + causal: bool = True, + sliding_window: int = 0, +) -> paddle.Tensor: + """ + append_attention + """ + if current_platform.is_cuda(): + out = decode_append_attention_cuda( + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + set_max_lengths, + attn_mask, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + cache_k_zp, + cache_v_zp, + mask_offset, + sinks, + cache_quant_type, + max_input_length, + quant_max_bound, + quant_min_bound, + max_tokens_per_batch, + causal, + sliding_window, + ) + return out + else: + raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py b/fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py new file mode 100644 index 00000000000..b10f6cd1bf6 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py @@ -0,0 +1,97 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from typing import Optional + +import paddle + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + decoder_write_cache_with_rope as decoder_write_cache_with_rope_cuda, + ) + + +def decoder_write_cache_with_rope( + qkv: paddle.Tensor, + key_cache: paddle.Tensor, + value_cache: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + batch_id_per_token: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + block_tables: paddle.Tensor, + set_max_lengths: paddle.Tensor, + rotary_embs: Optional[paddle.Tensor] = None, + qkv_bias: Optional[paddle.Tensor] = None, + k_quant_scale: Optional[paddle.Tensor] = None, + v_quant_scale: Optional[paddle.Tensor] = None, + k_dequant_scale: Optional[paddle.Tensor] = None, + v_dequant_scale: Optional[paddle.Tensor] = None, + cache_k_zp: Optional[paddle.Tensor] = None, + cache_v_zp: Optional[paddle.Tensor] = None, + kv_signal_data: Optional[paddle.Tensor] = None, + q_norm_weight: Optional[paddle.Tensor] = None, + k_norm_weight: Optional[paddle.Tensor] = None, + rms_norm_eps: float = 1e-6, + cache_quant_type: str = "none", + use_neox_rotary_style: bool = False, + rope_3d: bool = False, + max_input_length: int = 0, + quant_max_bound: float = 0.0, + quant_min_bound: float = 0.0, + speculate_decoder: bool = False, +) -> paddle.Tensor: + """ + append_attention + """ + if current_platform.is_cuda(): + qkv_out = decoder_write_cache_with_rope_cuda( + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + set_max_lengths, + rotary_embs, + qkv_bias, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + cache_k_zp, + cache_v_zp, + kv_signal_data, + q_norm_weight, + k_norm_weight, + rms_norm_eps, + cache_quant_type, + use_neox_rotary_style, + rope_3d, + max_input_length, + quant_max_bound, + quant_min_bound, + speculate_decoder, + ) + return qkv_out + else: + raise NotImplementedError diff --git a/fastdeploy/platforms/base.py b/fastdeploy/platforms/base.py index 9db9ebf77ba..c3c1bd7ef24 100644 --- a/fastdeploy/platforms/base.py +++ b/fastdeploy/platforms/base.py @@ -29,6 +29,7 @@ class _Backend(enum.Enum): PLAS_ATTN = enum.auto() HPU_ATTN = enum.auto() FLASH_MASK_ATTN = enum.auto() + DECODE_APPEND_ATTN = enum.auto() class Platform: diff --git a/fastdeploy/platforms/cuda.py b/fastdeploy/platforms/cuda.py index 8d0d559fe38..2b2ce854c26 100644 --- a/fastdeploy/platforms/cuda.py +++ b/fastdeploy/platforms/cuda.py @@ -58,6 +58,9 @@ def get_attention_backend_cls(cls, selected_backend: _Backend): elif selected_backend == _Backend.APPEND_ATTN: logger.info("Using APPEND ATTN backend.") return "fastdeploy.model_executor.layers.attention.AppendAttentionBackend" + elif selected_backend == _Backend.DECODE_APPEND_ATTN: + logger.info("Using DECODE APPEND ATTN backend.") + return "fastdeploy.model_executor.layers.attention.DecodeAppendAttentionBackend" elif selected_backend == _Backend.MLA_ATTN: logger.info("Using MLA ATTN backend.") return "fastdeploy.model_executor.layers.attention.MLAAttentionBackend" diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 976b1852760..cca6f5e7945 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -281,6 +281,8 @@ def _initialize_attn_backend( encoder_block_shape_q=encoder_block_shape_q, decoder_block_shape_q=decoder_block_shape_q, ) 
+ if envs.FD_ATTENTION_BACKEND == "DECODE_APPEND_ATTN": + attn_backend.init_buffer(self.scheduler_config.max_num_seqs) if attn_backend is None: raise NotImplementedError( "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly." diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 9dc3e9c534f..5db65db9732 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1548,7 +1548,8 @@ def _initialize_attn_backend(self) -> None: encoder_block_shape_q=encoder_block_shape_q, decoder_block_shape_q=decoder_block_shape_q, ) - + if envs.FD_ATTENTION_BACKEND == "DECODE_APPEND_ATTN": + attn_backend.init_buffer(self.scheduler_config.max_num_seqs) self.attn_backends.append(attn_backend) def _dummy_pooler_run_task( diff --git a/tests/ce/server/test_logprobs.py b/tests/ce/server/test_logprobs.py index 83ca89486c9..9bc858532f1 100644 --- a/tests/ce/server/test_logprobs.py +++ b/tests/ce/server/test_logprobs.py @@ -25,10 +25,10 @@ def test_unstream_with_logprobs(): # 校验返回内容与概率信息 assert resp_json["choices"][0]["message"]["content"] == "牛顿的" assert resp_json["choices"][0]["logprobs"]["content"][0]["token"] == "牛顿" - assert resp_json["choices"][0]["logprobs"]["content"][0]["logprob"] == -0.031025361269712448 + assert resp_json["choices"][0]["logprobs"]["content"][0]["logprob"] == -0.03135016933083534 assert resp_json["choices"][0]["logprobs"]["content"][0]["top_logprobs"][0] == { "token": "牛顿", - "logprob": -0.031025361269712448, + "logprob": -0.03135016933083534, "bytes": [231, 137, 155, 233, 161, 191], "top_logprobs": None, } @@ -102,10 +102,10 @@ def test_stream_with_logprobs(): # 校验概率字段 assert result_chunk["choices"][0]["delta"]["content"] == "牛顿" assert result_chunk["choices"][0]["logprobs"]["content"][0]["token"] == "牛顿" - assert result_chunk["choices"][0]["logprobs"]["content"][0]["logprob"] == -0.031025361269712448 + assert result_chunk["choices"][0]["logprobs"]["content"][0]["logprob"] == -0.03135016933083534 assert result_chunk["choices"][0]["logprobs"]["content"][0]["top_logprobs"][0] == { "token": "牛顿", - "logprob": -0.031025361269712448, + "logprob": -0.03135016933083534, "bytes": [231, 137, 155, 233, 161, 191], } @@ -187,10 +187,10 @@ def test_stream_with_temp_scaled_logprobs(): # 校验概率字段 assert result_chunk["choices"][0]["delta"]["content"] == "牛顿" assert result_chunk["choices"][0]["logprobs"]["content"][0]["token"] == "牛顿" - assert result_chunk["choices"][0]["logprobs"]["content"][0]["logprob"] == -0.006811376195400953 + assert result_chunk["choices"][0]["logprobs"]["content"][0]["logprob"] == -0.006874244660139084 assert result_chunk["choices"][0]["logprobs"]["content"][0]["top_logprobs"][0] == { "token": "牛顿", - "logprob": -0.006811376195400953, + "logprob": -0.006874244660139084, "bytes": [231, 137, 155, 233, 161, 191], } diff --git a/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py b/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py index e51018f201e..331ba78054e 100644 --- a/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py +++ b/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py @@ -205,7 +205,7 @@ def test_consistency_between_runs(api_url, headers, consistent_payload): # base result base_path = os.getenv("MODEL_PATH") if base_path: - base_file = os.path.join(base_path, "ernie-4_5-vl-base-tp2-dev") + base_file = os.path.join(base_path, "ernie-4_5-vl-base-tp2-1131") else: base_file = "ernie-4_5-vl-base-tp2-dev" with open(base_file, "r") as f: diff --git 
a/tests/e2e/test_EB_VL_Lite_serving.py b/tests/e2e/test_EB_VL_Lite_serving.py index f93f355a754..bb9e213e8ce 100644 --- a/tests/e2e/test_EB_VL_Lite_serving.py +++ b/tests/e2e/test_EB_VL_Lite_serving.py @@ -204,7 +204,7 @@ def test_consistency_between_runs(api_url, headers, consistent_payload): # base result base_path = os.getenv("MODEL_PATH") if base_path: - base_file = os.path.join(base_path, "ernie-4_5-vl-base-tp2-dev") + base_file = os.path.join(base_path, "ernie-4_5-vl-base-tp2-1131") else: base_file = "ernie-4_5-vl-base-tp2-dev" with open(base_file, "r") as f: diff --git a/tests/model_loader/test_torch_model.py b/tests/model_loader/test_torch_model.py index bc8252a4427..69843a7b999 100644 --- a/tests/model_loader/test_torch_model.py +++ b/tests/model_loader/test_torch_model.py @@ -140,7 +140,7 @@ def test_model_against_baseline( # Get baseline suffix from config model_config = hugging_face_model_param_map.get(model_name_or_path, {}) - baseline_suffix = model_config.get("baseline_suffix", "tp2") + baseline_suffix = model_config.get("baseline_suffix", "tp2_1131") baseline_filename = f"{model_name_or_path}-{baseline_suffix}" if base_path: diff --git a/tests/operators/attention/test_decode_attention.py b/tests/operators/attention/test_decode_attention.py new file mode 100644 index 00000000000..51938b459af --- /dev/null +++ b/tests/operators/attention/test_decode_attention.py @@ -0,0 +1,875 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
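+# This test builds a plain-Paddle reference (RopeEmbedding + naive_attention_impl,
+# with an explicit causal / sliding-window mask) and exercises the fused
+# config_for_attention -> decoder_write_cache_with_rope -> decode_append_attention
+# decode path against it. The reference RoPE uses inv_freq[i] = base**(-2*i/head_dim)
+# and supports both the interleaved (GPT-style) and half-split (NeoX-style) rotations.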
+ +import copy +import random +import unittest + +import numpy as np +import paddle +from paddle.incubate.nn.functional import fused_rms_norm + +from fastdeploy.model_executor.layers.attention.ops import ( + config_for_attention, + decode_append_attention, + decoder_write_cache_with_rope, + get_block_shape_and_split_kv_block, + gqa_rope_write_cache, + pre_cache_len_concat, +) + +seed = 1000 + +random.seed(seed) +np.random.seed(seed) +paddle.seed(seed) + + +class RopeEmbedding: + def __init__(self, use_neox_rotary_style=False): + self.use_neox_rotary_style = use_neox_rotary_style + self.base = 10000 + + def get_neox_style_position_embedding(self, position_ids, head_dim): + bsz, max_seq_len = position_ids.shape[:2] + rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim), dtype="float32") + inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim) + + # shape: [B, S, D/2] + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + # shape: [B, S, 1, D] + emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, 1, head_dim)) + + rot_emb[0] = paddle.cos(emb) + rot_emb[1] = paddle.sin(emb) + return rot_emb + + def get_rotary_position_embedding(self, position_ids, head_dim): + bsz, max_seq_len = position_ids.shape[:2] + rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32") + inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim) + + # shape: [B, S, D/2] + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + # shape: [B, S, D/2] + emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, head_dim // 2)) + # shape: [B, S, 1, D/2] + emb = paddle.unsqueeze(emb, 2) + + rot_emb[0] = paddle.cos(emb) + rot_emb[1] = paddle.sin(emb) + return rot_emb + + def _apply_rope(self, rotary_emb, q, k, cache_len): + # sin [sequence_length, embed_size_per_head//2] + # cos [sequence_length, embed_size_per_head//2] + # sin, cos = paddle.chunk(rp, 2, axis=-1) + seq, head_dim = q.shape[2], q.shape[3] + cos, sin = paddle.chunk(rotary_emb, 2, axis=0) + cos = cos[:, :, cache_len : cache_len + seq, ...] + sin = sin[:, :, cache_len : cache_len + seq, ...] 
+        cos = paddle.squeeze(cos, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :]
+        sin = paddle.squeeze(sin, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :]
+        # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
+
+        if self.use_neox_rotary_style:
+            sin_pos = sin
+            cos_pos = cos
+            # NeoX style: rotate the first and second halves of the head dim as two blocks
+            rotate_half_q = paddle.reshape(
+                paddle.concat(
+                    [
+                        -q[:, :, :, q.shape[-1] // 2 :],
+                        q[:, :, :, : q.shape[-1] // 2],
+                    ],
+                    axis=-1,
+                ),
+                paddle.shape(q),
+            )
+            rotate_half_k = paddle.reshape(
+                paddle.concat(
+                    [
+                        -k[:, :, :, k.shape[-1] // 2 :],
+                        k[:, :, :, : k.shape[-1] // 2],
+                    ],
+                    axis=-1,
+                ),
+                paddle.shape(k),
+            )
+        else:
+            sin_pos = paddle.reshape(paddle.stack([sin, sin], axis=-1), [1, 1, seq, head_dim])
+            # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
+            cos_pos = paddle.reshape(paddle.stack([cos, cos], axis=-1), [1, 1, seq, head_dim])
+            # GPT style: rotate interleaved even/odd positions in pairs
+            rotate_half_q = paddle.reshape(
+                paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1),
+                paddle.shape(q),
+            )
+            rotate_half_k = paddle.reshape(
+                paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1),
+                paddle.shape(k),
+            )
+
+        query = paddle.add(paddle.multiply(q, cos_pos), paddle.multiply(rotate_half_q, sin_pos))
+
+        key = paddle.add(paddle.multiply(k, cos_pos), paddle.multiply(rotate_half_k, sin_pos))
+
+        return paddle.cast(query, q.dtype), paddle.cast(key, k.dtype)
+
+
+def create_attn_mask(mask_type, batch_size, seq_lens, pre_cache_length=0, sliding_window=0):
+    max_seq_len = max(seq_lens)
+    mask = paddle.zeros(
+        # [batch_size, 1, max_seq_len, max_seq_len + pre_cache_length],
+        [batch_size, 1, max_seq_len, max_seq_len],
+        dtype=mask_type,
+    )
+    mask[:, :, :, :pre_cache_length] = 1
+    for i in range(batch_size):
+        seq_len = seq_lens[i]
+        ones_tensor = paddle.ones(shape=(seq_len, seq_len), dtype=mask_type)
+        if sliding_window <= 0:
+            mask[i, 0, :seq_len, :seq_len] = (paddle.tril(ones_tensor) - 1) * 1e4
+        else:
+            tmp_triu = paddle.triu(ones_tensor, -(sliding_window - 1))
+            mask[i, 0, :seq_len, :seq_len] = (paddle.tril(ones_tensor) * tmp_triu - 1) * 1e4
+    return mask
+
+
+def naive_attention_impl(
+    query,
+    key,
+    value,
+    pre_key=None,
+    pre_value=None,
+    mask=None,
+    scale=1.0,
+    cache_k_dequant_scales=None,
+    cache_v_dequant_scales=None,
+    use_cachekv_int8="None",
+    q_norm_weight=None,
+    k_norm_weight=None,
+    sinks=None,
+):
+    batch = query.shape[0]
+    heads = query.shape[1]
+    seq_len = query.shape[2]
+    head_dim = query.shape[3]
+    kv_head = key.shape[1]
+
+    key = key.reshape([batch, kv_head, 1, seq_len, head_dim])
+    key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1])
+    key = key.reshape([batch, heads, seq_len, head_dim])
+
+    if pre_key is not None:
+        pre_key = pre_key.reshape([batch, kv_head, 1, -1, head_dim])
+        pre_key = paddle.tile(pre_key, [1, 1, heads // kv_head, 1, 1])
+        pre_key = pre_key.reshape([batch, heads, -1, head_dim])
+        key = paddle.concat([pre_key, key], axis=2)
+
+    value = value.reshape([batch, kv_head, 1, seq_len, head_dim])
+    value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1])
+    value = value.reshape([batch, heads, seq_len, head_dim])
+
+    if pre_value is not None:
+        pre_value = pre_value.reshape([batch, kv_head, 1, -1, head_dim])
+        pre_value = paddle.tile(pre_value, [1, 1, heads // kv_head, 1, 1])
+        pre_value = pre_value.reshape([batch, heads, -1, head_dim])
+        value = paddle.concat([pre_value, value], axis=2)
+
+    qk_res = paddle.matmul(query, key, transpose_y=True)
+    attention = qk_res * scale
+    if mask is not None:
+
attention = attention + mask + + if sinks is not None: + kv_len = attention.shape[-1] + sinks_tiled = sinks.unsqueeze([0, 2, 3]).expand([batch, heads, seq_len, 1]) + attention = paddle.concat([attention, sinks_tiled], axis=-1) + softmax_result = paddle.nn.functional.softmax(attention, -1)[:, :, :, :kv_len] + else: + softmax_result = paddle.nn.functional.softmax(attention, -1) + result = paddle.matmul(paddle.cast(softmax_result, dtype=value.dtype), value) + return result + + +def get_padding_offset(bsz, seq_lens_this_time): + token_num = paddle.sum(seq_lens_this_time) + batch_id_per_token = paddle.zeros(shape=(token_num), dtype="int32") + cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32") + cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32") + index = 0 + for i in range(bsz): + seq_len_now = seq_lens_this_time[i].item() + for j in range(seq_len_now): + batch_id_per_token[index] = i + index += 1 + cu_seqlens_q[i + 1] = index + cu_seqlens_k[i + 1] = index + return batch_id_per_token, cu_seqlens_q, cu_seqlens_k + + +def remove_padding(seq_lens, cu_seq_lens, inputs, token_num): + bsz, num_head, seq_len, head_dim = inputs.shape + output = paddle.zeros(shape=[token_num, num_head * head_dim], dtype=inputs.dtype) + inputs = inputs.transpose([0, 2, 1, 3]).reshape([bsz, seq_len, -1]) + for i in range(bsz): + seq_len_now = seq_lens[i] + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + output[start_idx:end_idx, :] = inputs[i, :seq_len_now, :] + return output + + +def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, head_dim, place, dtype): + query = np.random.random([bs, q_num_head, seq_len, head_dim]) + q = paddle.to_tensor(query, place=place, dtype=dtype, stop_gradient=False) / 10.0 + key = np.random.random([bs, kv_num_head, seq_len, head_dim]) + k = paddle.to_tensor(key, place=place, dtype=dtype, stop_gradient=False) / 10.0 + value = np.random.random([bs, kv_num_head, seq_len, head_dim]) + v = paddle.to_tensor(value, place=place, dtype=dtype, stop_gradient=False) / 10.0 + token_num = bs * seq_len + + qkv = paddle.concat( + [ + q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * head_dim]), + k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + ], + axis=1, + ).reshape([token_num, -1]) + return q, k, v, qkv + + +class TestDecodeAppendAttention(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeAppendAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 16 + self.kv_num_head = 2 + self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.mask_matrix = False + self.use_sinks = False + self.causal = False + self.use_dynamic_quant = False + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + def init_tensor(self): + # seq_lens + if self.seq_len_dec is None: + self.seq_lens_dec = [ + self.cache_len, + ] * self.batch_size 
+        else:
+            self.batch_size = len(self.seq_lens_dec)
+        self.seq_lens_decoder = paddle.to_tensor(
+            self.seq_lens_dec,
+            "int32",
+        )
+        if self.seq_lens_this_time is None:
+            self.seq_lens_this_time = [
+                self.max_tokens_per_batch,
+            ] * self.batch_size
+        self.token_num = sum(self.seq_lens_this_time)
+        self.seq_lens_this_time = paddle.to_tensor(self.seq_lens_this_time, "int32")
+
+        self.seq_lens_enc = [0] * self.batch_size
+
+        self.seq_lens_encoder = paddle.to_tensor(
+            self.seq_lens_enc,
+            "int32",
+        )
+
+        # self.qkv = paddle.rand([self.token_num, (self.q_num_head + 2 * self.kv_num_head) * self.head_dim], dtype=self.dtype)
+        self.q, self.k, self.v, self.qkv = get_qkv_and_qkv_concat_tensor(
+            self.batch_size,
+            self.q_num_head,
+            self.kv_num_head,
+            self.max_tokens_per_batch,
+            self.head_dim,
+            self.place,
+            self.dtype,
+        )
+        self.qkv = paddle.to_tensor(self.qkv, dtype=self.dtype)
+
+        # qk_norm
+        self.q_norm_weight = None
+        self.k_norm_weight = None
+        if self.use_qk_norm:
+            q_norm_weight_np = np.random.random([self.head_dim]) / 10
+            k_norm_weight_np = np.random.random([self.head_dim]) / 10
+            self.q_norm_weight = paddle.to_tensor(q_norm_weight_np, dtype="float32")
+            self.k_norm_weight = paddle.to_tensor(k_norm_weight_np, dtype="float32")
+
+        # rotary embedding
+        self.rope = RopeEmbedding(False)
+        tmp_position_ids = paddle.arange(self.max_model_len).reshape((1, -1))
+        self.rotary_embs = self.rope.get_rotary_position_embedding(tmp_position_ids, self.head_dim)
+
+        # block_table
+        self.block_num_per_seq = (self.max_model_len + self.block_size - 1) // self.block_size
+        self.max_block_num = self.block_num_per_seq * self.batch_size
+        self.free_list = list(range(self.max_block_num - 1, -1, -1))
+        self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32")
+        for i in range(self.batch_size):
+            need_block_num = (self.max_model_len + self.block_size - 1) // self.block_size
+            for j in range(need_block_num):
+                self.block_tables[i, j] = self.free_list.pop()
+
+        # cache_kv && scale
+        self.cache_shape = (
+            self.max_block_num,
+            self.kv_num_head,
+            self.block_size,
+            self.head_dim,
+        )
+
+        if self.use_dynamic_quant:
+            self.cache_scale_shape = (
+                self.max_block_num,
+                self.kv_num_head,
+                self.block_size,
+            )
+            self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8")
+            self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8")
+            self.cache_k_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
+            self.cache_v_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
+            self.cache_k_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype)
+            self.cache_v_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype)
+            self.cache_k_out_scale = None
+            self.cache_v_out_scale = None
+        else:
+            self.cache_k_scale = self.quant_max_bound / self.k.transpose([1, 0, 2, 3]).reshape(
+                [self.kv_num_head, -1]
+            ).max(axis=1)
+            self.cache_v_scale = self.quant_max_bound / self.v.transpose([1, 0, 2, 3]).reshape(
+                [self.kv_num_head, -1]
+            ).max(axis=1)
+
+            self.cache_k_out_scale = (
+                self.k.transpose([1, 0, 2, 3]).reshape([self.kv_num_head, -1]).max(axis=1) / self.quant_max_bound
+            )
+            self.cache_v_out_scale = (
+                self.v.transpose([1, 0, 2, 3]).reshape([self.kv_num_head, -1]).max(axis=1) / self.quant_max_bound
+            )
+
+            self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8")
+            self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8")
+
+        (
+            self.batch_id_per_token,
+            self.cu_seqlens_q,
+            self.cu_seqlens_k,
+        ) = 
get_padding_offset(self.batch_size, self.seq_lens_this_time) + + # mask + if self.mask_matrix: + self.attn_mask = create_attn_mask( + self.dtype, + self.batch_size, + [ + self.max_tokens_per_batch, + ] + * self.batch_size, + sliding_window=self.sliding_window, + ) + else: + self.attn_mask = None + + # mask offset + self.mask_offset = None + if self.use_mask_offset: + self.mask_offset = paddle.full(self.batch_size * 2, 0, "int32") + for i in range(self.batch_size): + self.mask_offset[i * 2] = 0 + self.mask_offset[i * 2 + 1] = self.seq_lens_dec[i] + 1 + print("decoder mask_offset: ", self.mask_offset) + + if self.use_sinks: + self.sinks = paddle.to_tensor( + np.random.random([self.q_num_head]), place=self.place, dtype=self.dtype, stop_gradient=False + ) + else: + self.sinks = None + + # buffer + self.buffer = {} + min_chunk_size = 128 + max_num_chunk = (self.max_model_len + min_chunk_size - 1) // min_chunk_size + self.group_size = self.q_num_head // self.kv_num_head + q_tile_size = 16 if self.max_tokens_per_batch * self.group_size <= 16 else 32 + q_tile_num = (self.max_tokens_per_batch * self.group_size + q_tile_size - 1) // q_tile_size + self.buffer["max_len_tensor_cpu"] = paddle.full([6], 0, dtype="int32").cpu() + # block_indices: Launched block's indices with 4 dimensions [batch_idx, kv_head_idx, chunk_idx, q_tile_idx] in decode append attention backend + self.buffer["block_indices"] = paddle.full( + [self.batch_size * self.kv_num_head * max_num_chunk * q_tile_num, 4], 0, dtype="int32" + ) + # num_blocks: Number of Launched blocks in decode append attention backend, researched by config_for_attention op + self.buffer["num_blocks"] = paddle.full([1], 0, dtype="int32") + # chunk_size: Chunk size for split kv cache in decode append attention backend, researched by config_for_attention op + self.buffer["chunk_size"] = paddle.full([1], 0, dtype="int32") + # tmp_workspace: Workspace tensor for temporary store the result before merging in decode append attention backend + self.buffer["tmp_workspace"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head * self.head_dim], + 0, + dtype=self.dtype, + ) + # tmp_m: Tmp_m tensor for temporary store the max value before merging in decode append attention backend + self.buffer["tmp_m"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + # tmp_d: Tmp_d tensor for temporary store the exponential sum before merging in decode append attention backend + self.buffer["tmp_d"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + + def apply_qk_norm(self, head_dim, dtype, q, k): + bs, q_num_head, seq_len, head_dim = q.shape + _, kv_num_head, _, _ = k.shape + + q = q.reshape([-1, head_dim]) + k = k.reshape([-1, head_dim]) + q = fused_rms_norm(q.astype("float32"), self.q_norm_weight, None, self.rms_norm_eps)[0].astype(dtype) + k = fused_rms_norm(k.astype("float32"), self.k_norm_weight, None, self.rms_norm_eps)[0].astype(dtype) + q = q.reshape([-1, q_num_head, seq_len, head_dim]) + k = k.reshape([-1, kv_num_head, seq_len, head_dim]) + return q, k + + def naive_attention(self, pre_k, pre_v): + q, k = self.rope._apply_rope(self.rotary_embs, self.q, self.k, self.cache_len) + if self.use_qk_norm: + q, k = self.apply_qk_norm(self.head_dim, self.dtype, q, k) + + out_ref = naive_attention_impl( + q, + k, + self.v, + pre_k, + pre_v, + self.attn_mask, + self.softmax_scale, + sinks=self.sinks, + ) + 
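+        # Editor's note: remove_padding flattens the padded [bs, heads, seq, head_dim]
+        # reference output into the packed [token_num, heads * head_dim] layout returned
+        # by the fused kernel, so the two results can be compared element-wise.
+        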
out_ref = remove_padding(self.seq_lens_this_time, self.cu_seqlens_q, out_ref, self.token_num) + return q, k, self.v, out_ref + + def decode_attention(self): + paddle.disable_static() + + config_for_attention( + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.buffer["block_indices"], + self.buffer["num_blocks"], + self.buffer["chunk_size"], + self.buffer["max_len_tensor_cpu"], + self.cache_quant_type, + self.group_size, + self.kv_num_head, + self.max_tokens_per_batch, + ) + + decoder_write_cache_with_rope( + self.qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.batch_id_per_token, + self.cu_seqlens_q, + self.block_tables, + self.buffer["max_len_tensor_cpu"], + self.rotary_embs, # rotary_embs + None, # qkv_bias + self.cache_k_scale, # cache_k_quant_scales + self.cache_v_scale, # cache_v_quant_scales + self.cache_k_out_scale, # cache_k_dequant_scales + self.cache_v_out_scale, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # kv_signal_data + self.q_norm_weight, # q_norm_weight + self.k_norm_weight, # k_norm_weight + self.rms_norm_eps, + self.cache_quant_type, + False, # use_neox_rotary_style + self.rope_3d, + self.max_model_len, + self.quant_max_bound, # quant_max_bound + self.quant_min_bound, # quant_min_bound + self.max_tokens_per_batch > 1, # speculate_decoder + ) + out = decode_append_attention( + self.qkv, + self.cache_k, + self.cache_v, + self.buffer["tmp_workspace"], + self.buffer["tmp_m"], + self.buffer["tmp_d"], + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.batch_id_per_token, + self.cu_seqlens_q, + self.block_tables, + self.buffer["block_indices"], + self.buffer["num_blocks"], + self.buffer["chunk_size"], + self.buffer["max_len_tensor_cpu"], # rope_emb + None, # attn_mask + self.cache_k_scale, # cache_k_quant_scales + self.cache_v_scale, # cache_v_quant_scales + self.cache_k_out_scale, # cache_k_dequant_scales + self.cache_v_out_scale, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + self.mask_offset, # mask_offset + self.sinks, # sinks + self.cache_quant_type, + self.max_model_len, + self.quant_max_bound, # quant_max_bound + self.quant_min_bound, # quant_min_bound + self.max_tokens_per_batch, # speculate_max_draft_token_num + self.causal, # causal + self.sliding_window, + ) + return self.qkv, out + # np.testing.assert_allclose( + # out.numpy(), + # out_.numpy(), + # rtol=1e-02, + # atol=1e-02, + # ) + + def prefill(self): + # init seq_len + seq_lens_encoder = copy.deepcopy(self.seq_lens_decoder) + seq_lens_decoder = paddle.zeros([self.batch_size], dtype="int32") + seq_lens_this_time = seq_lens_encoder + token_num = seq_lens_this_time.sum().item() + qkv_np = np.random.random([token_num, (self.q_num_head + 2 * self.kv_num_head) * self.head_dim]) / 10.0 + qkv = paddle.to_tensor(qkv_np, dtype=self.dtype) + + ( + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + ) = get_padding_offset(self.batch_size, seq_lens_this_time) + # buffer + decode_max_tile_size = 1024 * self.batch_size * np.ceil((2 * 10) / 16) + decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") + decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") + max_num_block = 
self.batch_size * (self.max_model_len * self.group_size + 64 - 1) // 64 + encoder_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + encoder_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + encoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + + kv_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + kv_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + max_len_tensor_cpu = paddle.full([6], 0, dtype="int32").cpu() + get_block_shape_and_split_kv_block( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_cpu, + decoder_num_blocks_device, + decoder_chunk_size_device, + max_len_tensor_cpu, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + 64, + 16, + self.group_size, + self.block_size, + ) + ( + cu_seqlens_k, + pre_cache_batch_ids, + pre_cache_tile_ids_per_batch, + pre_cache_num_blocks_cpu, + kv_token_num_cpu, + ) = pre_cache_len_concat( + seq_lens_decoder, + seq_lens_this_time, + max_len_tensor_cpu[2], + self.block_size, + ) + q, k, v, _ = gqa_rope_write_cache( + qkv, + self.cache_k, + self.cache_v, + cu_seqlens_q, + cu_seqlens_k, + self.rotary_embs, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + batch_id_per_token, + self.block_tables, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + pre_cache_batch_ids, + pre_cache_tile_ids_per_batch, + pre_cache_num_blocks_cpu, + self.q_norm_weight, + self.k_norm_weight, + self.cache_k_scale, # cache_k_quant_scales + self.cache_v_scale, # cache_v_quant_scales + self.cache_k_out_scale, # cache_k_dequant_scales + self.cache_v_out_scale, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # kv_signal_data + kv_token_num_cpu[0].item(), + self.max_model_len, + self.rms_norm_eps, + self.cache_quant_type, + self.rope_3d, + ) + + k = k.reshape([self.batch_size, -1, self.kv_num_head, self.head_dim]).transpose([0, 2, 1, 3]) + v = v.reshape([self.batch_size, -1, self.kv_num_head, self.head_dim]).transpose([0, 2, 1, 3]) + return k, v + + def test_all(self): + pre_k, pre_v = self.prefill() + + q_ref, k_ref, v_ref, out_ref = self.naive_attention(pre_k, pre_v) + qkv_out, out = self.decode_attention() + + np.testing.assert_allclose( + out.astype("float32").numpy(), + out_ref.astype("float32").numpy(), + rtol=1e-03, + atol=1e-03, + ) + + +class TestDecodeAppendAttentionMultiBatch(TestDecodeAppendAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeAppendAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.mask_matrix = False + self.use_sinks = False + self.causal = False + self.use_dynamic_quant = False + self.quant_min_bound = 
-448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeAppendAttentionSpeculate(TestDecodeAppendAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeAppendAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.mask_matrix = False + self.use_sinks = False + self.causal = False + self.use_dynamic_quant = False + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeAppendAttentionMultiHead(TestDecodeAppendAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeAppendAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 16 + self.kv_num_head = 2 + self.batch_size = 6 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.mask_matrix = False + self.use_sinks = False + self.causal = False + self.use_dynamic_quant = False + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeAppendAttentionMultiSpeculate(TestDecodeAppendAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeAppendAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 4 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.mask_matrix = False + self.use_sinks = False + self.causal = False + self.use_dynamic_quant = False + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeAppendAttentionQKNorm(TestDecodeAppendAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeAppendAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec 
= None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = True + self.use_mask_offset = False + self.mask_matrix = False + self.use_sinks = False + self.causal = False + self.use_dynamic_quant = False + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +if __name__ == "__main__": + unittest.main()