[Feature] support stop_token_ids #4382
Changes to the CUDA stop-flag kernel `set_value_by_flags`: the signature gains per-request `stop_token_ids` inputs (with per-request lengths and a padded max length) plus `min_tokens` and `max_tokens` arrays.

```cuda
@@ -37,11 +37,17 @@ __global__ void set_value_by_flags(bool *stop_flags,
                                   const int *stop_seqs_len,
                                   const int stop_seqs_bs,
                                   const int stop_seqs_max_len,
                                   const int64_t *stop_token_ids,
                                   const int *stop_token_ids_len,
                                   const int stop_token_ids_max_len,
                                   const int64_t *min_tokens,
                                   const int64_t *max_tokens,
                                   bool beam_search,
                                   bool prefill_one_step_stop) {
  int tid = threadIdx.x;
  int bid = blockIdx.x;
  if (tid >= stop_seqs_bs) return;
  if (bid < bs) {
    if (tid == 0) {
      if (prefill_one_step_stop) {
```
The per-step stop logic now checks `max_tokens` first, gates every stop condition on `min_tokens`, and then matches the sampled token against the request's `stop_token_ids`:

```cuda
@@ -62,13 +68,69 @@ __global__ void set_value_by_flags(bool *stop_flags,
          next_tokens[bid] = topk_ids[bid];
        }
      }
      // check max_tokens
      const int64_t current_step = step_idx[bid];
      const int64_t max_token_limit = max_tokens[bid];
      if (current_step >= max_token_limit) {
        stop_flags[bid] = true;
        next_tokens[bid] = end_ids[0];
        topk_ids[bid] = end_ids[0];
        return;
      }
      // check min_tokens
      const int64_t min_token_limit = min_tokens[bid];
      const bool below_min_tokens = current_step < min_token_limit;
      // If min_tokens has not been reached, no stop condition may fire yet
      if (below_min_tokens) {
        if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {
          return;
        }
      } else {
        if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {
          stop_flags[bid] = true;
        }
        // match the sampled token against this request's stop_token_ids
        if (!stop_flags[bid] && stop_token_ids != nullptr && stop_token_ids_len != nullptr) {
          const int num_stop_tokens = stop_token_ids_len[bid];
          if (num_stop_tokens > 0) {
            const int64_t last_token_id = topk_ids[bid];
            if (last_token_id >= 0) {
              const int64_t *stop_tokens_now = stop_token_ids + bid * stop_token_ids_max_len;
              for (int i = 0; i < num_stop_tokens; ++i) {
                const int64_t stop_token = stop_tokens_now[i];
                if (stop_token >= 0 && last_token_id == stop_token) {
                  stop_flags[bid] = true;
                  next_tokens[bid] = end_ids[0];
                  topk_ids[bid] = end_ids[0];
                  break;
                }
              }
            }
          }
        }
      }
    }
    // handle stop_seqs
    const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
    if (stop_seq_len <= 0) return;
    // stop_seqs matching is also gated on min_tokens
    const int64_t current_step = step_idx[bid];
    const int64_t min_token_limit = min_tokens[bid];
    if (current_step < min_token_limit) return;
    const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
    const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
    const int64_t step_idx_now = step_idx[bid];
```
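For reference, here is a minimal Python sketch of the decision order this kernel implements on the greedy path. The function name and scalar view are illustrative only, not part of the PR; beam search and the multi-token `stop_seqs` matching are omitted.

```python
def should_stop(step_idx: int, token: int, end_ids: list[int],
                stop_token_ids: list[int], min_tokens: int, max_tokens: int) -> bool:
    """Reference model of the per-request stop decision in set_value_by_flags."""
    # 1. max_tokens wins unconditionally: stop once the step index reaches it.
    if step_idx >= max_tokens:
        return True
    # 2. Below min_tokens, no stop condition may fire, not even an EOS token.
    if step_idx < min_tokens:
        return False
    # 3. Otherwise an EOS token or any user-supplied stop token ends generation.
    return token in end_ids or token in stop_token_ids
```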
On the host side, `GetStopFlagsMulti` accepts the new tensors, treats `stop_token_ids` / `stop_token_ids_len` as optional, and forwards raw pointers to the kernel; the static op registration gains the new inputs:

```cuda
@@ -90,6 +152,8 @@ __global__ void set_value_by_flags(bool *stop_flags,
  }
}

void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
                       const paddle::Tensor &stop_flags,
                       const paddle::Tensor &seq_lens,
@@ -99,6 +163,10 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
                       const paddle::Tensor &step_idx,
                       const paddle::Tensor &stop_seqs,
                       const paddle::Tensor &stop_seqs_len,
                       const paddle::Tensor &stop_token_ids,
                       const paddle::Tensor &stop_token_ids_len,
                       const paddle::Tensor &min_tokens,
                       const paddle::Tensor &max_tokens,
                       const bool beam_search) {
  PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
  PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
@@ -121,6 +189,16 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
  int64_t end_length = end_ids.shape()[0];
  int stop_seqs_bs = stop_seqs.shape()[1];
  int stop_seqs_max_len = stop_seqs.shape()[2];
  // stop_token_ids is optional; pass null pointers when it is absent
  int stop_token_ids_max_len = 0;
  const int64_t *stop_token_ids_ptr = nullptr;
  const int *stop_token_ids_len_ptr = nullptr;
  if (stop_token_ids.is_initialized() && stop_token_ids_len.is_initialized()) {
    stop_token_ids_max_len = stop_token_ids.shape()[1];  // [bs, max_stop_tokens]
    stop_token_ids_ptr = stop_token_ids.data<int64_t>();
    stop_token_ids_len_ptr = stop_token_ids_len.data<int>();
  }
  int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
  set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>(
      const_cast<bool *>(stop_flags.data<bool>()),
@@ -137,12 +215,17 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
      stop_seqs_len.data<int>(),
      stop_seqs_bs,
      stop_seqs_max_len,
      stop_token_ids_ptr,
      stop_token_ids_len_ptr,
      stop_token_ids_max_len,
      min_tokens.data<int64_t>(),
      max_tokens.data<int64_t>(),
      beam_search,
      prefill_one_step_stop);
}

PD_BUILD_STATIC_OP(set_stop_value_multi_ends)
    .Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens",
             "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len",
             "stop_token_ids", "stop_token_ids_len", "min_tokens", "max_tokens"})
    .Attrs({"beam_search: bool"})
    .Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"})
    .SetInplaceMap({{"topk_ids", "topk_ids_out"},
```
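The kernel reads `stop_token_ids` as a dense int64 `[bs, max_stop_tokens]` buffer and skips negative entries, which suggests padding ragged per-request lists with -1. Below is a sketch of such a packing step; the helper name is hypothetical, and -1 as the pad value is an inference from the `stop_token >= 0` check rather than something stated in the PR.

```python
import numpy as np

def pack_stop_token_ids(per_request: list[list[int]]) -> tuple[np.ndarray, np.ndarray]:
    """Pack ragged stop-token lists into the dense layout the kernel expects:
    int64 [bs, max_stop_tokens] padded with -1, plus int32 lengths per request."""
    bs = len(per_request)
    max_len = max((len(ids) for ids in per_request), default=0)
    padded = np.full((bs, max(max_len, 1)), -1, dtype=np.int64)
    lengths = np.zeros(bs, dtype=np.int32)
    for i, ids in enumerate(per_request):
        padded[i, :len(ids)] = ids
        lengths[i] = len(ids)
    return padded, lengths
```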
Changes to the Python `SamplingParams` dataclass: the `stop` parameter is renamed to `stop_seqs` in the docstring, the dataclass fields, and `from_optional`:

```python
@@ -58,7 +58,7 @@ class SamplingParams:
        considered, relative to the probability of the most likely token.
        Must be in [0, 1]. Set to 0 to disable this.
    seed: Random seed to use for the generation.
    stop_seqs: list of strings that stop the generation when they are generated.
        The returned output will not contain the stop strings.
    stop_token_ids: list of tokens that stop the generation when they are
        generated. The returned output will contain the stop tokens unless
@@ -90,7 +90,7 @@ class SamplingParams:
    top_k: int = 0
    min_p: float = 0.0
    seed: Optional[int] = None
    stop_seqs: Optional[Union[str, List[str]]] = None
    stop_token_ids: Optional[List[int]] = None
    stop_seqs_len: Optional[int] = None
    max_tokens: Optional[int] = None
@@ -127,7 +127,7 @@ def from_optional(
        top_k,
        min_p,
        seed=None,
        stop_seqs=None,
        stop_token_ids=None,
        max_tokens=None,
        reasoning_max_tokens=None,
@@ -149,7 +149,7 @@ def from_optional(
        top_k=top_k if top_k is not None else 0,
        min_p=min_p if min_p is not None else 0.0,
        seed=seed,
        stop_seqs=stop_seqs,
        stop_token_ids=stop_token_ids,
        max_tokens=max_tokens if max_tokens is not None else 8192,
        reasoning_max_tokens=reasoning_max_tokens,
```
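A hypothetical usage sketch with the renamed field. The import path is an assumption, only fields visible in this diff are set, and the remaining dataclass fields are left at their defaults.

```python
from fastdeploy.engine.sampling_params import SamplingParams  # assumed path

params = SamplingParams(
    top_k=40,
    seed=42,
    stop_seqs=["\nUser:"],     # formerly `stop`
    stop_token_ids=[151645],   # example ids; matched per step by the kernel
    max_tokens=256,
)
```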
Review discussion:

Reviewer: Is the `max_tokens` check here necessary? There should already be an existing limit enforcing `max_tokens`.

Author: This handles the case where a token in `stop_token_ids` would only be generated beyond `max_tokens`: generation then still cuts off at `max_tokens`, so I think the check is necessary.
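In terms of the `should_stop` reference sketch above, the author's point is that the `max_tokens` check fires before any stop-token match is attempted, so a stop token that would only appear past the limit can never extend generation (values below are illustrative):

```python
# A request capped at max_tokens=4 stops at step 4 regardless of whether
# its stop token (999) has been generated yet.
assert should_stop(step_idx=4, token=123, end_ids=[2],
                   stop_token_ids=[999], min_tokens=0, max_tokens=4)
```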