Commit 385fe6d
[Others] clean code (#5133)
1 parent 7ac2593 commit 385fe6d

3 files changed: +51 −57 lines changed

custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu

Lines changed: 35 additions & 44 deletions
@@ -311,16 +311,17 @@ void GetBlockShapeAndSplitKVBlock(
   if (mla_backend && group_size <= 64) {
     const int set_chunk_size = get_mla_dec_chunk_size(bsz);
 
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
+    CUDA_CHECK(cudaMemsetAsync(
         decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
 
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
+    CUDA_CHECK(cudaMemsetAsync(
         decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
 
     int device;
-    cudaGetDevice(&device);
+    CUDA_CHECK(cudaGetDevice(&device));
     int sm_cout;
-    cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device);
+    CUDA_CHECK(cudaDeviceGetAttribute(
+        &sm_cout, cudaDevAttrMultiProcessorCount, device));
     constexpr int config_size =
         12;  // search space for chunk size:[64, 128, 256, ... 131072]
 
@@ -341,16 +342,14 @@ void GetBlockShapeAndSplitKVBlock(
         decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
     const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
 
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        cudaMemsetAsync(decoder_batch_ids.data<int>(),
-                        0,
-                        decoder_batch_ele_num * sizeof(int32_t),
-                        stream));
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
-                        0,
-                        decoder_batch_ele_num * sizeof(int32_t),
-                        stream));
+    CUDA_CHECK(cudaMemsetAsync(decoder_batch_ids.data<int>(),
+                               0,
+                               decoder_batch_ele_num * sizeof(int32_t),
+                               stream));
+    CUDA_CHECK(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
+                               0,
+                               decoder_batch_ele_num * sizeof(int32_t),
+                               stream));
 
     split_block_for_mla<<<1, 32, 0, stream>>>(
         seq_lens_this_time.data<int>(),
@@ -362,17 +361,15 @@ void GetBlockShapeAndSplitKVBlock(
         chunk_size);
 
   } else {
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        cudaMemsetAsync(decoder_batch_ids.data<int>(),
-                        0,
-                        decoder_batch_ele_num * sizeof(int32_t),
-                        stream));
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
-                        0,
-                        decoder_batch_ele_num * sizeof(int32_t),
-                        stream));
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
+    CUDA_CHECK(cudaMemsetAsync(decoder_batch_ids.data<int>(),
+                               0,
+                               decoder_batch_ele_num * sizeof(int32_t),
+                               stream));
+    CUDA_CHECK(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
+                               0,
+                               decoder_batch_ele_num * sizeof(int32_t),
+                               stream));
+    CUDA_CHECK(cudaMemsetAsync(
         decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
 
     split_q_block<<<1, 32, 0, stream>>>(
@@ -391,8 +388,6 @@ void GetBlockShapeAndSplitKVBlock(
 #endif
     decoder_num_blocks_cpu.copy_(
         decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
-        decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
   }
 }
 
@@ -401,19 +396,17 @@ void GetBlockShapeAndSplitKVBlock(
     const uint32_t max_tile_size_per_bs_kv =
         div_up(max_enc_dec_len_this_time, block_size);
     const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
+    CUDA_CHECK(cudaMemsetAsync(
         kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(),
-                        0,
-                        kv_batch_shape * sizeof(int32_t),
-                        stream));
+    CUDA_CHECK(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(),
+                               0,
+                               kv_batch_shape * sizeof(int32_t),
+                               stream));
     auto kv_num_blocks_x =
         GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
 
     split_kv_block<<<1, 32, 0, seq_lens_encoder.stream()>>>(
         seq_lens_decoder.data<int>(),
-        // sequence_lengths->data<int>(),
         seq_lens_encoder.data<int>(),
         kv_batch_ids.data<int>(),
         kv_tile_ids_per_batch.data<int>(),
@@ -428,16 +421,14 @@ void GetBlockShapeAndSplitKVBlock(
     const uint32_t encoder_max_tile_size_per_bs_q =
         div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
     const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        cudaMemsetAsync(encoder_batch_ids.data<int>(),
-                        0,
-                        encoder_batch_shape * sizeof(int32_t),
-                        stream));
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(),
-                        0,
-                        encoder_batch_shape * sizeof(int32_t),
-                        stream));
+    CUDA_CHECK(cudaMemsetAsync(encoder_batch_ids.data<int>(),
+                               0,
+                               encoder_batch_shape * sizeof(int32_t),
+                               stream));
+    CUDA_CHECK(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(),
+                               0,
+                               encoder_batch_shape * sizeof(int32_t),
+                               stream));
     auto encoder_num_blocks_x =
         GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
     split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(),
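Note: throughout this file the commit replaces PADDLE_ENFORCE_GPU_SUCCESS with CUDA_CHECK and wraps the previously unchecked cudaGetDevice / cudaDeviceGetAttribute calls. The macro's definition is not part of this diff; as a point of reference, a minimal sketch of such a wrapper, assuming a report-and-abort policy (FastDeploy's actual macro may throw or route through Paddle's error handling instead), could look like:

// Minimal sketch of a CUDA_CHECK-style macro (assumed, not FastDeploy's
// actual definition): evaluate a CUDA runtime call and, on failure, report
// where it happened and abort.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CUDA_CHECK(call)                                       \
  do {                                                         \
    cudaError_t err_ = (call);                                 \
    if (err_ != cudaSuccess) {                                 \
      std::fprintf(stderr, "CUDA error %s at %s:%d: %s\n",     \
                   cudaGetErrorName(err_), __FILE__, __LINE__, \
                   cudaGetErrorString(err_));                  \
      std::abort();                                            \
    }                                                          \
  } while (0)

The do { ... } while (0) wrapper makes the macro behave as a single statement, so CUDA_CHECK(...); composes safely inside if/else branches such as the mla_backend split above.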

fastdeploy/model_executor/layers/attention/append_attn_backend.py

Lines changed: 13 additions & 10 deletions
@@ -72,31 +72,34 @@ def allocate_launch_related_buffer(
     block_size,
 ):
     # Initialize AttentionBackend buffers
-    group_size = np.ceil(num_heads / kv_num_heads)
+    assert num_heads % kv_num_heads == 0
+    assert max_model_len % block_size == 0
+    assert max_model_len % encoder_block_shape_q == 0
+    group_size = num_heads // kv_num_heads
 
     # NOTE: (changwenbin) When using auto_chunk,
     # decode_max_tile_size must take into account the maximum case, where *1024 can cover 128K.
     decode_max_tile_size = (
-        1024 * max_batch_size * np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q)
+        1024 * max_batch_size * (int)(np.ceil(decoder_step_token_num * group_size / decoder_block_shape_q))
     )
-    encode_max_tile_size = max_batch_size * np.ceil((max_model_len * group_size) / encoder_block_shape_q)
-    kv_max_tile_size = max_batch_size * np.ceil(max_model_len / block_size)
+    encode_max_tile_size = max_batch_size * (max_model_len * group_size // encoder_block_shape_q)
+    kv_max_tile_size = max_batch_size * (max_model_len // block_size)
     res = {}
-    res["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
-    res["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
+    res["decoder_batch_ids"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
+    res["decoder_tile_ids_per_batch"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
     res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
     # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor,
     # adapted to cudagraph.
     res["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32")
     res["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32")
     res["max_len_tensor_cpu"] = paddle.full([9], 0, dtype="int32").cpu()
 
-    res["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
-    res["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
+    res["encoder_batch_ids"] = paddle.full([encode_max_tile_size], 0, dtype="int32")
+    res["encoder_tile_ids_per_batch"] = paddle.full([encode_max_tile_size], 0, dtype="int32")
     res["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
 
-    res["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
-    res["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
+    res["kv_batch_ids"] = paddle.full([kv_max_tile_size], 0, dtype="int32")
+    res["kv_tile_ids_per_batch"] = paddle.full([kv_max_tile_size], 0, dtype="int32")
     res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
     return res
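As a worked check on the new integer sizing (with hypothetical values, since the commit pins none): for num_heads = 32, kv_num_heads = 4, max_batch_size = 64, decoder_step_token_num = 1, and decoder_block_shape_q = 16, group_size = 32 // 4 = 8 and ceil(1 * 8 / 16) = 1, so decode_max_tile_size = 1024 * 64 * 1 = 65536 int32 entries per decoder buffer; the fixed 1024 factor is the auto-chunk headroom the NOTE says covers 128K contexts. The three new asserts guarantee every floor division is exact, and because the sizes are now plain Python ints rather than NumPy floats, the int(...) casts around the paddle.full shapes could be dropped.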

tests/entrypoints/openai/test_run_batch.py

Lines changed: 3 additions & 3 deletions
@@ -172,7 +172,7 @@ async def test_initialize_engine_client(self, mock_engine_client):
         mock_args = Mock()
         mock_args.model = "test-model"
         mock_args.tokenizer = "test-tokenizer"
-        mock_args.max_model_len = 1000
+        mock_args.max_model_len = 1024
         mock_args.tensor_parallel_size = 1
         mock_args.engine_worker_queue_port = [8000]
         mock_args.local_data_parallel_id = 0
@@ -202,7 +202,7 @@ async def test_initialize_engine_client(self, mock_engine_client):
     def test_create_serving_handlers(self, mock_chat_handler, mock_model_handler):
         """Test creating the serving handlers."""
         mock_args = Mock()
-        mock_args.max_model_len = 1000
+        mock_args.max_model_len = 1024
         mock_args.ips = "127.0.0.1"
         mock_args.max_waiting_time = 60
         mock_args.enable_mm_output = False
@@ -1286,7 +1286,7 @@ def run_fastdeploy_command(self, input_content, port=None):
             "--quantization",
             "wint4",
             "--max-model-len",
-            "4192",
+            "5120",
             "--max-num-seqs",
             "64",
             "--load-choices",
