Commits (31)
- fecd8ef  fix stop_seqs (lizexu123, Oct 13, 2025)
- 4f24a2d  support stop_token_ids (lizexu123, Oct 14, 2025)
- 20e5b58  merge develop (lizexu123, Oct 16, 2025)
- 1c9a1e3  support min_tokens (lizexu123, Oct 16, 2025)
- 3d5c730  Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy … (lizexu123, Oct 16, 2025)
- c26be0d  support min_tokens and max_tokens stop (lizexu123, Oct 17, 2025)
- 394a24a  add test (lizexu123, Oct 17, 2025)
- b0dcbf7  add FD_STOP_TOKEN_IDS_MAX_LEN (lizexu123, Oct 17, 2025)
- 39deb26  merge develop (lizexu123, Oct 17, 2025)
- 8d76a61  delete print (lizexu123, Oct 17, 2025)
- 6e93c15  fix dummy_run (lizexu123, Oct 17, 2025)
- 601d30c  update (lizexu123, Oct 31, 2025)
- 78e9aa5  fix stop (lizexu123, Oct 31, 2025)
- d067a6f  code-prefix (lizexu123, Oct 31, 2025)
- 9ec111b  fix max_tokens (lizexu123, Oct 31, 2025)
- c2d8965  update (lizexu123, Nov 3, 2025)
- 700bdcc  delete max_tokens (lizexu123, Nov 5, 2025)
- c116078  Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy … (lizexu123, Nov 5, 2025)
- 5cbfde4  Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy … (lizexu123, Nov 7, 2025)
- 870f192  delete max_tokens (lizexu123, Nov 7, 2025)
- 82a2473  fix (lizexu123, Nov 7, 2025)
- 83a9f79  fix (lizexu123, Nov 12, 2025)
- d04c1bb  Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy … (lizexu123, Nov 12, 2025)
- c0aa25d  update develop (lizexu123, Nov 12, 2025)
- 185aa4f  Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy … (lizexu123, Nov 13, 2025)
- d582e5c  add stop_token_ids test (lizexu123, Nov 13, 2025)
- fe8d942  fix (lizexu123, Nov 13, 2025)
- c721157  add ducument (lizexu123, Nov 13, 2025)
- e49a056  fix document (lizexu123, Nov 13, 2025)
- f02fbff  fix test (lizexu123, Nov 13, 2025)
- d2ea29f  delete print (lizexu123, Nov 13, 2025)
3 changes: 3 additions & 0 deletions custom_ops/gpu_ops/cpp_extensions.cc
@@ -415,6 +415,9 @@ void GetStopFlagsMulti(const paddle::Tensor& topk_ids,
const paddle::Tensor& step_idx,
const paddle::Tensor& stop_seqs,
const paddle::Tensor& stop_seqs_len,
const paddle::Tensor& stop_token_ids,
const paddle::Tensor& stop_token_ids_len,
const paddle::Tensor& min_tokens,
const bool beam_search);

void UpdateInputs(const paddle::Tensor& stop_flags,
240 changes: 158 additions & 82 deletions custom_ops/gpu_ops/stop_generation_multi_ends.cu
@@ -37,59 +37,106 @@ __global__ void set_value_by_flags(bool *stop_flags,
const int *stop_seqs_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int64_t *stop_token_ids,
const int *stop_token_ids_len,
const int stop_token_ids_max_len,
const int64_t *min_tokens,
bool beam_search,
bool prefill_one_step_stop) {
int tid = threadIdx.x;
int bid = blockIdx.x;
if (tid >= stop_seqs_bs) return;
if (bid < bs) {
if(tid == 0){
if (prefill_one_step_stop) {
stop_flags[bid] = true;
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
}
next_tokens[bid] = topk_ids[bid];
} else {
if (stop_flags[bid]) {
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
} else {
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
}
} else {
next_tokens[bid] = topk_ids[bid];
}
}
if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {
stop_flags[bid] = true;
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
}
}
// dealing stop_seqs
const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
if (stop_seq_len <= 0) return;
const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
const int64_t step_idx_now = step_idx[bid];

bool is_end = true;
int count = 1;
for (int i = stop_seq_len - 1; i >= 0; --i) {
if ((step_idx_now - count) < 0 ||
pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) {
is_end = false;
break;
}
int tid = threadIdx.x;
int bid = blockIdx.x;
if (tid >= stop_seqs_bs) return;

if (bid < bs) {
if (tid == 0) {
if (prefill_one_step_stop) {
stop_flags[bid] = true;
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
}
if (is_end) {
next_tokens[bid] = end_ids[0];
stop_flags[bid] = true;
next_tokens[bid] = topk_ids[bid];
} else {
if (stop_flags[bid]) {
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
} else {
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
}
} else {
next_tokens[bid] = topk_ids[bid];
}
}

const int64_t current_step = step_idx[bid];
// check min_tokens
const int64_t min_token_limit = min_tokens[bid];
const bool below_min_tokens = current_step < min_token_limit;

// If haven't reached min_tokens, cannot stop for any reason
if (below_min_tokens) {
if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {
return;
Review comment (Collaborator): Doesn't this return whenever below_min_tokens is satisfied?
Reply (lizexu123): done
}
} else {
if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {
stop_flags[bid] = true;
}

if (!stop_flags[bid] && stop_token_ids != nullptr &&
stop_token_ids_len != nullptr) {
const int num_stop_tokens = stop_token_ids_len[bid];

if (num_stop_tokens > 0) {
const int64_t last_token_id = topk_ids[bid];

if (last_token_id >= 0) {
const int64_t *stop_tokens_now =
stop_token_ids + bid * stop_token_ids_max_len;

for (int i = 0; i < num_stop_tokens; ++i) {
const int64_t stop_token = stop_tokens_now[i];

if (stop_token >= 0 && last_token_id == stop_token) {
stop_flags[bid] = true;
next_tokens[bid] = end_ids[0];
topk_ids[bid] = end_ids[0];
break;
}
}
}
}
}
}
}

const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
if (stop_seq_len <= 0) return;

const int64_t current_step = step_idx[bid];
const int64_t min_token_limit = min_tokens[bid];
if (current_step < min_token_limit) return;

const int64_t *stop_seq_now =
stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
const int64_t step_idx_now = step_idx[bid];

bool is_end = true;
int count = 1;
for (int i = stop_seq_len - 1; i >= 0; --i) {
if ((step_idx_now - count) < 0 ||
pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) {
is_end = false;
break;
}
}
if (is_end) {
next_tokens[bid] = end_ids[0];
stop_flags[bid] = true;
topk_ids[bid] = end_ids[0];
}
}
}

void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
@@ -101,50 +101,79 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::Tensor &stop_token_ids,
const paddle::Tensor &stop_token_ids_len,
const paddle::Tensor &min_tokens,
const bool beam_search) {
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n';
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n';
Review comment (Collaborator): Please also delete the leftover debug print code while you are at it.
Reply (lizexu123): done
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
}

#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place()));
auto cu_stream = dev_ctx->stream();
auto dev_ctx = static_cast<const phi::CustomContext *>(
paddle::experimental::DeviceContextPool::Instance().Get(
topk_ids.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = topk_ids.stream();
auto cu_stream = topk_ids.stream();
#endif
std::vector<int64_t> shape = topk_ids.shape();
int64_t bs_now = shape[0];
int64_t end_length = end_ids.shape()[0];
int stop_seqs_bs = stop_seqs.shape()[1];
int stop_seqs_max_len = stop_seqs.shape()[2];
int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(next_tokens.data<int64_t>()),
end_ids.data<int64_t>(),
seq_lens.data<int>(),
bs_now,
end_length,
pre_ids.data<int64_t>(),
pre_ids.shape()[1],
step_idx.data<int64_t>(),
stop_seqs.data<int64_t>(),
stop_seqs_len.data<int>(),
stop_seqs_bs,
stop_seqs_max_len,
beam_search,
prefill_one_step_stop);
std::vector<int64_t> shape = topk_ids.shape();
int64_t bs_now = shape[0];
int64_t end_length = end_ids.shape()[0];
int stop_seqs_bs = stop_seqs.shape()[1];
int stop_seqs_max_len = stop_seqs.shape()[2];

int stop_token_ids_max_len = 0;
const int64_t *stop_token_ids_ptr = nullptr;
const int *stop_token_ids_len_ptr = nullptr;
if (stop_token_ids.is_initialized() && stop_token_ids_len.is_initialized()) {
stop_token_ids_max_len = stop_token_ids.shape()[1]; // [bs,max_stop_tokens]
stop_token_ids_ptr = stop_token_ids.data<int64_t>();
stop_token_ids_len_ptr = stop_token_ids_len.data<int>();
}
int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(next_tokens.data<int64_t>()),
end_ids.data<int64_t>(),
seq_lens.data<int>(),
bs_now,
end_length,
pre_ids.data<int64_t>(),
pre_ids.shape()[1],
step_idx.data<int64_t>(),
stop_seqs.data<int64_t>(),
stop_seqs_len.data<int>(),
stop_seqs_bs,
stop_seqs_max_len,
stop_token_ids_ptr,
stop_token_ids_len_ptr,
stop_token_ids_max_len,
min_tokens.data<int64_t>(),
beam_search,
prefill_one_step_stop);
}

PD_BUILD_STATIC_OP(set_stop_value_multi_ends)
.Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len"})
.Inputs({"topk_ids",
"stop_flags",
"seq_lens",
"end_ids",
"next_tokens",
"pre_ids",
"step_idx",
"stop_seqs",
"stop_seqs_len",
"stop_token_ids",
"stop_token_ids_len",
"min_tokens"})
.Attrs({"beam_search: bool"})
.Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"})
.SetInplaceMap({{"topk_ids", "topk_ids_out"},
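For readers skimming the diff: the hunk above interleaves the previous kernel body with the rewritten one, so the new logic is easier to follow as a summary. Below is a minimal host-side sketch, in Python, of the per-request stop decision the new set_value_by_flags performs; should_stop is a hypothetical name, the flat list layout is an assumption, and the kernel's rewriting of next_tokens/topk_ids to end_ids[0] is omitted.

def should_stop(last_token, step_idx, min_tokens, end_ids,
                stop_token_ids, stop_seqs, pre_ids):
    # Nothing may stop the request before min_tokens tokens have been generated.
    if step_idx < min_tokens:
        return False
    # 1) The sampled token is an EOS token.
    if last_token in end_ids:
        return True
    # 2) The sampled token matches one of the per-request stop_token_ids.
    if any(t >= 0 and t == last_token for t in stop_token_ids):
        return True
    # 3) The generated history ends with one of the multi-token stop sequences.
    history = pre_ids[:step_idx]
    for seq in stop_seqs:
        if seq and history[-len(seq):] == seq:
            return True
    return False

In the CUDA kernel, checks 1 and 2 run in the tid == 0 thread, while the stop-sequence suffix match is parallelized over tid, one stop sequence per thread up to stop_seqs_bs.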
5 changes: 3 additions & 2 deletions fastdeploy/config.py
@@ -307,8 +307,9 @@ def read_from_env(self):
Read configuration information from environment variables and update the object's attributes.
If an attribute is not present or is an empty string in the environment variables, use the default value.
"""
self.max_stop_seqs_num = envs.FD_MAX_STOP_SEQS_NUM
self.stop_seqs_max_len = envs.FD_STOP_SEQS_MAX_LEN
self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
self.stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
self.stop_token_ids_max_len = int(envs.FD_STOP_TOKEN_IDS_MAX_LEN)

def reset_config_value(key, value):
if not hasattr(self, key.lower()):
4 changes: 3 additions & 1 deletion fastdeploy/envs.py
@@ -37,7 +37,9 @@
# Maximum number of stop sequences.
"FD_MAX_STOP_SEQS_NUM": lambda: int(os.getenv("FD_MAX_STOP_SEQS_NUM", "5")),
# Maximum length of stop sequences.
"FD_STOP_SEQS_MAX_LEN": lambda: int(os.getenv("FD_STOP_SEQS_MAX_LEN", "8")),
"FD_STOP_SEQS_MAX_LEN": lambda: os.getenv("FD_STOP_SEQS_MAX_LEN", "8"),
# Maximum length of stop token ids.
"FD_STOP_TOKEN_IDS_MAX_LEN": lambda: os.getenv("FD_STOP_SEQS_MAX_LEN", "8"),
Review comment (Copilot AI, Nov 14, 2025): The environment variable name is incorrect. It reads from FD_STOP_SEQS_MAX_LEN instead of FD_STOP_TOKEN_IDS_MAX_LEN. This will cause stop_token_ids to use the wrong environment variable, making it impossible to configure them separately from stop sequences. Change to: lambda: os.getenv("FD_STOP_TOKEN_IDS_MAX_LEN", "8")

Suggested change:
-    "FD_STOP_TOKEN_IDS_MAX_LEN": lambda: os.getenv("FD_STOP_SEQS_MAX_LEN", "8"),
+    "FD_STOP_TOKEN_IDS_MAX_LEN": lambda: os.getenv("FD_STOP_TOKEN_IDS_MAX_LEN", "8"),
# GPU devices that will be used. This is a string that
# splited by comma, such as 0,1,2.
"CUDA_VISIBLE_DEVICES": lambda: os.getenv("CUDA_VISIBLE_DEVICES", None),
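The new limit is meant to be configurable through FD_STOP_TOKEN_IDS_MAX_LEN (default "8"), with the int cast now happening in config.read_from_env() rather than in the lambda. A small sketch of the intended behavior, assuming the suggested fix above is applied:

import os

# Mirrors the intended envs.py entry after the suggested fix (illustrative only).
FD_STOP_TOKEN_IDS_MAX_LEN = os.getenv("FD_STOP_TOKEN_IDS_MAX_LEN", "8")

# config.read_from_env() then casts the raw string to int:
stop_token_ids_max_len = int(FD_STOP_TOKEN_IDS_MAX_LEN)
print(stop_token_ids_max_len)  # 8 unless FD_STOP_TOKEN_IDS_MAX_LEN is set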
4 changes: 2 additions & 2 deletions fastdeploy/input/ernie4_5_processor.py
@@ -96,7 +96,7 @@ def process_request(self, request, max_model_len=None, **kwargs):
stop_sequences = request.get("stop", [])
if stop_sequences is not None and len(stop_sequences) != 0:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request.set("stop_token_ids", stop_seqs)
request.set("stop", stop_seqs)
request.set("stop_seqs_len", stop_seqs_len)

# processing bad_words
@@ -177,7 +177,7 @@ def process_request_dict(self, request, max_model_len=None):
stop_sequences = request.get("stop", [])
if stop_sequences:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request["stop_token_ids"] = stop_seqs
request["stop"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len

# processing bad_words
2 changes: 1 addition & 1 deletion fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py
@@ -212,7 +212,7 @@ def process_request_dict(self, request, max_model_len=None):
stop_sequences = request.get("stop", [])
if stop_sequences:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request["stop_token_ids"] = stop_seqs
request["stop"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len

bad_words = request.get("bad_words")
5 changes: 3 additions & 2 deletions fastdeploy/input/text_processor.py
@@ -216,7 +216,7 @@ def process_request(self, request, max_model_len=None, **kwargs):
stop_sequences = request.get("stop", [])
if stop_sequences is not None and len(stop_sequences) != 0:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request.set("stop_token_ids", stop_seqs)
request.set("stop_seqs", stop_seqs)
Review comment (Collaborator): Shouldn't this be unified with the other places, i.e. call it either stop or stop_seqs everywhere?
Reply (lizexu123): done
request.set("stop_seqs_len", stop_seqs_len)

# processing bad_words
@@ -284,14 +284,15 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
"""
data_processor_logger.info(f"Start processing request dict: {request}")
request = self._apply_default_parameters(request)

if not request.get("eos_token_ids"):
request["eos_token_ids"] = self.eos_token_ids

# processing stop_sequences
stop_sequences = request.get("stop", [])
if stop_sequences:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request["stop_token_ids"] = stop_seqs
request["stop"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len

# processing bad_words
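To make the processor changes above concrete, here is a hypothetical request walk-through; the stop strings, token ids, and -1 padding are made up, and update_stop_seq is only paraphrased, not reproduced:

# Incoming request with both kinds of stop criteria (values are illustrative).
request = {
    "prompt": "Write a haiku.",
    "stop": ["\n\n", "###"],      # stop strings, tokenized by update_stop_seq
    "stop_token_ids": [151645],   # per-request single-token stops
    "min_tokens": 8,              # no stop condition may fire before 8 tokens
}

# update_stop_seq tokenizes each stop string and pads it to FD_STOP_SEQS_MAX_LEN
# (default 8); the exact token ids depend on the tokenizer.
stop_seqs = [
    [271, -1, -1, -1, -1, -1, -1, -1],    # "\n\n" as one token, padded
    [14711, -1, -1, -1, -1, -1, -1, -1],  # "###" as one token, padded
]
stop_seqs_len = [1, 1]

# process_request_dict writes the tokenized form back onto the request,
# keeping the key "stop" rather than the earlier "stop_token_ids".
request["stop"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len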
11 changes: 9 additions & 2 deletions fastdeploy/model_executor/pre_and_post_process.py
@@ -343,6 +343,7 @@ def post_process_normal(
),
model_output.step_idx,
)

length_cond = paddle.greater_equal(model_output.step_idx, model_output.max_dec_len)
paddle.assign(
paddle.logical_or(model_output.stop_flags, length_cond),
@@ -358,8 +359,11 @@
model_output.next_tokens,
model_output.pre_ids,
model_output.step_idx,
model_output.stop_token_ids,
model_output.stop_seqs,
model_output.stop_seqs_len,
model_output.stop_token_ids,
model_output.stop_token_ids_len,
model_output.min_tokens,
False,
) # multi ends
elif current_platform.is_maca():
@@ -371,8 +375,11 @@
model_output.next_tokens,
model_output.pre_ids,
model_output.step_idx,
model_output.stop_token_ids,
model_output.stop_seqs,
model_output.stop_seqs_len,
model_output.stop_token_ids,
model_output.stop_token_ids_len,
model_output.min_tokens,
False,
) # multi ends
else:
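Note that the last two hunks appear to list model_output.stop_token_ids twice because the single old argument line sits next to the three newly added ones (stop_token_ids, stop_token_ids_len, min_tokens). As a rough sketch of how per-request stop token ids could be packed into the dense [bs, max_stop_tokens] layout the kernel indexes with bid * stop_token_ids_max_len; pack_stop_token_ids is an assumed helper, not FastDeploy API:

import numpy as np

def pack_stop_token_ids(per_request_ids, max_len=8, pad=-1):
    """per_request_ids: one list of stop token ids per request in the batch."""
    bs = len(per_request_ids)
    packed = np.full((bs, max_len), pad, dtype=np.int64)  # [bs, max_stop_tokens]
    lengths = np.zeros((bs,), dtype=np.int32)             # stop_token_ids_len
    for i, ids in enumerate(per_request_ids):
        ids = list(ids)[:max_len]  # truncate to FD_STOP_TOKEN_IDS_MAX_LEN
        packed[i, :len(ids)] = ids
        lengths[i] = len(ids)
    return packed, lengths

stop_token_ids, stop_token_ids_len = pack_stop_token_ids([[2], [151645, 151643]])
# stop_token_ids.shape == (2, 8); stop_token_ids_len.tolist() == [1, 2]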