234 changes: 213 additions & 21 deletions custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
@@ -11,10 +11,11 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cute/tensor.hpp"
#include "helper.h"
#include "paddle/extension.h"
#include "paddle/phi/core/memory/memcpy.h"
#include "utils.cuh"

template <int THREADBLOCK_SIZE>
__global__ void
@@ -116,6 +117,93 @@ void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
max_len_tensor.data<int>(), batch_size);
}

template <uint32_t config_size>
__global__ void search_chunk_size_for_mla(
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ seq_lens_decoder,
int *__restrict__ num_blocks_x,
int *__restrict__ res_chunk_size,
const int bsz,
const int set_chunk_size,
const int block_size,
const int sm_count) {
const uint32_t conf_id = threadIdx.x;
int gridx = 0;
if (set_chunk_size > 0 && conf_id == 0) {
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
if (seq_len == 0 || seq_len_encoder > 0) continue;

int loop_times = cute::ceil_div(seq_len_decoder, set_chunk_size);
gridx += loop_times;
}
*num_blocks_x = gridx;
*res_chunk_size = set_chunk_size;
} else if (conf_id < config_size) {
__shared__ int gridx_shared[config_size];
// chunk_size is a multiple of 64
const int chunk_size = block_size << conf_id;
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
if (seq_len == 0 || seq_len_encoder > 0) continue;

int loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
gridx += loop_times;
}
gridx_shared[conf_id] = gridx;
__syncthreads();
if (threadIdx.x == 0) {
uint32_t res_id = 0;
uint32_t max_last_wave_block = 0;
for (uint32_t i = 1; i < config_size; i++) {
uint32_t last_wave_block = gridx_shared[i] % sm_count;
if (last_wave_block >= max_last_wave_block) {
res_id = i;
max_last_wave_block = last_wave_block;
}
}
*num_blocks_x = gridx_shared[res_id];
*res_chunk_size = block_size << res_id;
}
}
}
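
// For orientation, the auto-search branch above can be read as a simple host-side
// heuristic: each candidate chunk size is block_size << conf_id, and the candidate
// whose total grid size leaves the fullest last wave across the SMs wins, with larger
// chunks preferred on ties. The following is an illustrative C++ mirror of that rule,
// not code from this PR; decode_lens and the function names are assumptions.
//
// // Illustrative host-side mirror of search_chunk_size_for_mla's auto-search
// // branch (set_chunk_size <= 0). Not part of the PR; names are assumptions.
// #include <vector>
//
// inline int ceil_div_host(int a, int b) { return (a + b - 1) / b; }
//
// // decode_lens holds seq_lens_decoder[bid] + seq_lens_q[bid] for each batch
// // slot that is decoding (seq_lens_q > 0 and seq_lens_encoder == 0).
// int pick_mla_chunk_size(const std::vector<int>& decode_lens,
//                         int block_size,     // 64 in the kernel above
//                         int sm_count,       // cudaDevAttrMultiProcessorCount
//                         int config_size) {  // 12: chunks up to block_size << 11
//   int best_conf = 0;
//   int best_last_wave = 0;
//   for (int conf = 1; conf < config_size; ++conf) {  // conf 0 skipped, as above
//     const int chunk_size = block_size << conf;
//     int grid = 0;
//     for (int len : decode_lens) grid += ceil_div_host(len, chunk_size);
//     const int last_wave = grid % sm_count;  // blocks left for the final wave
//     if (last_wave >= best_last_wave) {      // '>=' prefers larger chunk sizes
//       best_conf = conf;
//       best_last_wave = last_wave;
//     }
//   }
//   return block_size << best_conf;
// }
//
// The % sm_count criterion is a tail-utilization proxy: a grid whose final wave
// covers most of the SMs wastes the least time in the last round of the split-KV decode.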

__global__ void split_block_for_mla(const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ seq_lens_decoder,
int *__restrict__ batch_ids,
int *__restrict__ tile_ids_per_batch,
const int bsz,
const int chunk_size) {
if (threadIdx.x == 0) {
int index = 0;
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;

if (seq_len == 0) continue;

int loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
if (seq_len_encoder > 0) {
loop_times = 0;
}
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
}
}
}
}
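
// To make the flattened (batch_id, tile_id) layout concrete, here is a hypothetical
// host-side rendering with made-up lengths; the helper and the main() driver are
// illustrative and not part of the PR.
//
// #include <cstdio>
// #include <utility>
// #include <vector>
//
// std::vector<std::pair<int, int>> enumerate_mla_tiles(
//     const std::vector<int>& seq_lens_q,
//     const std::vector<int>& seq_lens_encoder,
//     const std::vector<int>& seq_lens_decoder,
//     int chunk_size) {
//   std::vector<std::pair<int, int>> schedule;  // (batch_id, tile_id) pairs
//   for (int bid = 0; bid < static_cast<int>(seq_lens_q.size()); ++bid) {
//     const int seq_len = seq_lens_q[bid];
//     if (seq_len == 0) continue;               // inactive slot
//     if (seq_lens_encoder[bid] > 0) continue;  // prefill batch: no decode tiles
//     const int total = seq_lens_decoder[bid] + seq_len;
//     const int tiles = (total + chunk_size - 1) / chunk_size;
//     for (int t = 0; t < tiles; ++t) schedule.emplace_back(bid, t);
//   }
//   return schedule;
// }
//
// int main() {
//   // Three slots: lengths 99+1, an idle slot, and 299+1, chunked by 128.
//   auto s = enumerate_mla_tiles({1, 0, 1}, {0, 0, 0}, {99, 0, 299}, 128);
//   for (auto& [b, t] : s) std::printf("(batch=%d, tile=%d)\n", b, t);
//   // Prints: (0,0) (2,0) (2,1) (2,2) -- four blocks in total.
// }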

__global__ void split_q_block(const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
int *__restrict__ batch_ids,
@@ -230,6 +318,9 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
paddle::Tensor kv_tile_ids_per_batch;
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
paddle::Tensor max_len_kv_cpu; /*cpu*/
paddle::Tensor decoder_num_blocks_x;
paddle::Tensor decoder_chunk_size_device;
paddle::Tensor decoder_chunk_size_cpu; /*cpu*/

auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
@@ -239,6 +330,103 @@

max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);

// decoder
if (max_dec_len_this_time > 0) {
const bool mla_use_tensorcore = GetMlaUseTensorcore();
if (mla_use_tensorcore && group_size <= 64) {
const int set_chunk_size = get_mla_dec_chunk_size(bsz);
decoder_chunk_size_device = GetEmptyTensor(
{1}, paddle::DataType::INT32, seq_lens_encoder.place());
decoder_num_blocks_x = GetEmptyTensor(
{1}, paddle::DataType::INT32, seq_lens_encoder.place());

int device;
cudaGetDevice(&device);
int sm_count;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
constexpr int config_size = 12;  // search space for chunk sizes: [64, 128, 256, ..., 131072]

search_chunk_size_for_mla<config_size>
<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
decoder_num_blocks_x.data<int>(),
decoder_chunk_size_device.data<int>(),
bsz,
set_chunk_size,
block_size,
sm_count);

decoder_chunk_size_cpu =
decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];

const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(),
0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(),
0, sizeof(int32_t), stream));

std::cout << "-----------------------------------------------------------"
<< std::endl;
std::cout << "chunk size1:================================ " << chunk_size
<< std::endl;
std::cout << "-----------------------------------------------------------"
<< std::endl;
split_block_for_mla<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
bsz,
chunk_size);
decoder_num_blocks_x_cpu.copy_(
decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);

} else {
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
decoder_batch_ids = GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
decoder_num_blocks_x = GetEmptyTensor(
{1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
decoder_num_blocks_x_cpu.copy_(
decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);

decoder_chunk_size_cpu = paddle::full(
{1}, 131072, paddle::DataType::INT32, paddle::CPUPlace());
}
} else {
decoder_chunk_size_cpu =
paddle::full({1}, 131072, paddle::DataType::INT32, paddle::CPUPlace());
decoder_num_blocks_x = paddle::full(
{1}, -1, paddle::DataType::INT32, seq_lens_encoder.place());
decoder_num_blocks_x_cpu.copy_(
decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
}
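
// For a feel for the scheduling-buffer sizes cleared and allocated in both branches
// above, a small worked example with assumed values (the concrete numbers are
// illustrative, not taken from this PR).
//
// #include <cstdint>
// #include <cstdio>
//
// constexpr uint32_t div_up_u32(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
//
// int main() {
//   const uint32_t bsz = 32;                    // assumed batch size
//   const uint32_t decoder_step_token_num = 1;  // tokens decoded per step (assumed)
//   const uint32_t group_size = 8;              // q heads per kv head (assumed)
//   const uint32_t decoder_block_shape_q = 16;  // q tile height (assumed)
//
//   const uint32_t tiles_per_bs_q =
//       div_up_u32(decoder_step_token_num * group_size, decoder_block_shape_q);  // 1
//   const uint32_t decoder_batch_shape = bsz * tiles_per_bs_q;                   // 32
//
//   // Both scheduling buffers hold decoder_batch_shape int32 entries.
//   std::printf("decoder_batch_shape = %u\n", decoder_batch_shape);
//   return 0;
// }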

// encoder
if (max_enc_len_this_time > 0) {
const uint32_t max_tile_size_per_bs_kv =
div_up(max_enc_dec_len_this_time, block_size);
@@ -292,27 +480,27 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
}

if (max_just_dec_len_this_time > 0) {
// Clear buffer
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
// if (max_just_dec_len_this_time > 0) {
// // Clear buffer
// const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
// const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
// PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
// PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
// PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));

auto decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
}
// auto decoder_num_blocks_x =
// GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
// split_q_block<<<1, 32, 0, stream>>>(
// seq_lens_this_time.data<int>(),
// seq_lens_encoder.data<int>(),
// decoder_batch_ids.data<int>(),
// decoder_tile_ids_per_batch.data<int>(),
// decoder_num_blocks_x.data<int>(),
// bsz,
// decoder_block_shape_q,
// group_size);
// decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
// }

return {
encoder_batch_ids,
@@ -321,6 +509,8 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks_x_cpu, /*cpu*/
decoder_num_blocks_x,
decoder_chunk_size_cpu, /*cpu*/
max_len_kv_cpu, /*cpu*/
};
}
@@ -342,6 +532,8 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
paddle::Optional("kv_batch_ids"),
paddle::Optional("kv_tile_ids_per_batch"),
paddle::Optional("kv_num_blocks_x_cpu"),
paddle::Optional("decoder_num_blocks_x"),
paddle::Optional("decoder_chunk_size_cpu"),
"max_len_kv_cpu"
})
.Attrs({
1 change: 1 addition & 0 deletions custom_ops/gpu_ops/cpp_extensions.cc
@@ -416,6 +416,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& decoder_chunk_size_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
9 changes: 9 additions & 0 deletions custom_ops/gpu_ops/env.h
@@ -62,3 +62,12 @@ inline bool get_mla_use_tensorcore() {
mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
return mla_use_tensorcore != 0 ? true : false;
}
inline int get_mla_dec_chunk_size(int bsz) {
static const char* mla_dec_chunk_size_env =
std::getenv("FLAGS_mla_dec_chunk_size");
static const int mla_dec_chunk_size =
mla_dec_chunk_size_env == nullptr
? -1
: std::stoi(std::string(mla_dec_chunk_size_env));
return bsz > 1 ? mla_dec_chunk_size : 64;
}
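// A brief usage note, as a sketch rather than documented behavior: the header caches
// the flag in a function-local static on first use, a batch size of 1 is always pinned
// to 64, and an unset flag yields -1 for larger batches, which the caller above treats
// as auto-search. The snippet re-implements the lookup inline so it stays
// self-contained; it is a mirror of the helper, not the helper itself, and unlike the
// header it re-reads the environment variable on every call.
//
// #include <cassert>
// #include <cstdlib>
// #include <string>
//
// static int mla_dec_chunk_size_sketch(int bsz) {
//   const char* env = std::getenv("FLAGS_mla_dec_chunk_size");
//   const int flag = env == nullptr ? -1 : std::stoi(std::string(env));
//   return bsz > 1 ? flag : 64;  // batch size 1 is always pinned to 64
// }
//
// int main() {
//   setenv("FLAGS_mla_dec_chunk_size", "2048", /*overwrite=*/1);
//   assert(mla_dec_chunk_size_sketch(1) == 64);     // bsz == 1 -> 64
//   assert(mla_dec_chunk_size_sketch(16) == 2048);  // otherwise the flag value
//   unsetenv("FLAGS_mla_dec_chunk_size");
//   assert(mla_dec_chunk_size_sketch(16) == -1);    // unset -> -1 (auto-search)
//   return 0;
// }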
8 changes: 8 additions & 0 deletions custom_ops/gpu_ops/helper.h
@@ -557,3 +557,11 @@ inline int GetSMVersion() {
return sm_version;

}

inline bool GetMlaUseTensorcore() {
static const bool flags_mla_use_tensorcore = get_mla_use_tensorcore();
static const bool enable_mla_tensorcore = GetSMVersion() >= 90;
const bool mla_use_tensorcore =
flags_mla_use_tensorcore && enable_mla_tensorcore;
return mla_use_tensorcore;
}
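
// For context, the decoder branch in get_block_shape_and_split_kv_block.cu gates the
// new MLA scheduling on this predicate together with the head-group ratio
// (group_size <= 64). The snippet below restates that eligibility check as a
// self-contained sketch; it is illustrative and not PR code.
//
// #include <cassert>
//
// // The tensor-core MLA decode path is eligible only when the env flag allows it,
// // the GPU reports SM 90 or newer, and the q-heads-per-kv-head ratio fits (<= 64).
// static bool mla_tc_path_eligible(bool flag_mla_use_tensorcore,
//                                  int sm_version,
//                                  int group_size) {
//   const bool tensorcore_ok = flag_mla_use_tensorcore && sm_version >= 90;
//   return tensorcore_ok && group_size <= 64;
// }
//
// int main() {
//   assert(mla_tc_path_eligible(true, 90, 16));    // Hopper, small group: yes
//   assert(!mla_tc_path_eligible(true, 80, 16));   // Ampere: no
//   assert(!mla_tc_path_eligible(false, 90, 16));  // flag disabled: no
//   assert(!mla_tc_path_eligible(true, 90, 128));  // group too large: no
//   return 0;
// }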
5 changes: 4 additions & 1 deletion custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu
@@ -79,6 +79,7 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int chunk_size,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
@@ -97,7 +98,7 @@ void BatchMLAWithPagedKVCacheKernel(
const auto q_head_num = meta_data.q_num_heads;
const auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
const auto max_block_num = bsz * max_block_num_per_seq;
const uint32_t chunk_size = get_max_partition_size(bsz);
// const uint32_t chunk_size = get_max_partition_size(bsz);


int q_head_dim = meta_data.head_dims;
@@ -185,6 +186,7 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int chunk_size,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
@@ -219,6 +221,7 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int chunk_size,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
(additional file; name not shown in this capture)
@@ -56,6 +56,7 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int chunk_size,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
Expand Down