test/benchmark/static_inference/model_infer.py (8 additions, 0 deletions)
@@ -185,17 +185,23 @@ def prefill(
     b_ready_cache_len,
 ):
     b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
+    b_prefill_start_loc = b_seq_len.cumsum(dim=0, dtype=torch.int32) - b_seq_len
     model_input = ModelInput(
         batch_size=batch_size,
         total_token_num=total_token_num,
         max_len_in_batch=max_len_in_batch,
+        max_q_seq_len=max_len_in_batch,
+        max_kv_seq_len=max_len_in_batch,
+        max_cache_len=0,
         input_ids=input_ids,
         b_req_idx=b_req_idx,
         b_seq_len=b_seq_len,
         b_mtp_index=b_mtp_index,
         mem_indexes_cpu=mem_indexes,
         is_prefill=True,
         b_ready_cache_len=b_ready_cache_len,  # b_ready_cache_len
+        b_prefill_start_loc=b_prefill_start_loc,
+        prefix_total_token_num=0,  # the default kvcache len is zero.
     )

     model_output = model_part.forward(model_input)
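For context, a minimal sketch (my own illustration, not part of the PR) of what the new b_prefill_start_loc line computes: an exclusive prefix sum of b_seq_len, i.e. the start offset of each request's tokens in the flattened prefill input. The lengths below are invented for the example.

    import torch

    # Hypothetical per-request prompt lengths for a batch of three requests.
    b_seq_len = torch.tensor([3, 5, 2], dtype=torch.int32)

    # cumsum gives each request's end offset in the flattened token buffer;
    # subtracting the lengths shifts it back to the start offset.
    b_prefill_start_loc = b_seq_len.cumsum(dim=0, dtype=torch.int32) - b_seq_len

    print(b_prefill_start_loc)  # tensor([0, 3, 8], dtype=torch.int32)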
@@ -209,6 +215,8 @@ def decode(
         batch_size=batch_size,
         total_token_num=total_token_num,
         max_len_in_batch=max_len_in_batch,
+        max_q_seq_len=1,
+        max_kv_seq_len=max_len_in_batch,
         input_ids=input_ids,
         b_req_idx=b_req_idx,
         b_seq_len=b_seq_len,
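The two decode additions follow from the shape of a decode step: each request appends a single query token while attention still reads the full cached prefix. A small sketch under that reading (the lengths are invented, not taken from the benchmark):

    import torch

    # Hypothetical cached sequence lengths during decode.
    b_seq_len = torch.tensor([17, 42, 9], dtype=torch.int32)

    max_q_seq_len = 1                      # one new query token per request per step
    max_kv_seq_len = int(b_seq_len.max())  # 42: the longest cached sequence in the batch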
Binary file added test/benchmark/static_inference/test.npy