-
Notifications
You must be signed in to change notification settings - Fork 659
[Feature] support stop_token_ids #4382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 22 commits
fecd8ef
4f24a2d
20e5b58
1c9a1e3
3d5c730
c26be0d
394a24a
b0dcbf7
39deb26
8d76a61
6e93c15
601d30c
78e9aa5
d067a6f
9ec111b
c2d8965
700bdcc
c116078
5cbfde4
870f192
82a2473
83a9f79
d04c1bb
c0aa25d
185aa4f
d582e5c
fe8d942
c721157
e49a056
f02fbff
d2ea29f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,59 +37,106 @@ __global__ void set_value_by_flags(bool *stop_flags, | |
| const int *stop_seqs_len, | ||
| const int stop_seqs_bs, | ||
| const int stop_seqs_max_len, | ||
| const int64_t *stop_token_ids, | ||
| const int *stop_token_ids_len, | ||
| const int stop_token_ids_max_len, | ||
| const int64_t *min_tokens, | ||
| bool beam_search, | ||
| bool prefill_one_step_stop) { | ||
| int tid = threadIdx.x; | ||
| int bid = blockIdx.x; | ||
| if (tid >= stop_seqs_bs) return; | ||
| if (bid < bs) { | ||
| if(tid == 0){ | ||
| if (prefill_one_step_stop) { | ||
| stop_flags[bid] = true; | ||
| if (seq_lens[bid] == 0) { | ||
| topk_ids[bid] = -1; | ||
| } | ||
| next_tokens[bid] = topk_ids[bid]; | ||
| } else { | ||
| if (stop_flags[bid]) { | ||
| if (seq_lens[bid] == 0) { | ||
| topk_ids[bid] = -1; | ||
| } else { | ||
| topk_ids[bid] = end_ids[0]; | ||
| next_tokens[bid] = end_ids[0]; | ||
| } | ||
| } else { | ||
| next_tokens[bid] = topk_ids[bid]; | ||
| } | ||
| } | ||
| if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) { | ||
| stop_flags[bid] = true; | ||
| topk_ids[bid] = end_ids[0]; | ||
| next_tokens[bid] = end_ids[0]; | ||
| } | ||
| } | ||
| // dealing stop_seqs | ||
| const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid]; | ||
| if (stop_seq_len <= 0) return; | ||
| const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len; | ||
| const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len; | ||
| const int64_t step_idx_now = step_idx[bid]; | ||
|
|
||
| bool is_end = true; | ||
| int count = 1; | ||
| for (int i = stop_seq_len - 1; i >= 0; --i) { | ||
| if ((step_idx_now - count) < 0 || | ||
| pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) { | ||
| is_end = false; | ||
| break; | ||
| } | ||
| int tid = threadIdx.x; | ||
| int bid = blockIdx.x; | ||
| if (tid >= stop_seqs_bs) return; | ||
|
|
||
| if (bid < bs) { | ||
| if (tid == 0) { | ||
| if (prefill_one_step_stop) { | ||
| stop_flags[bid] = true; | ||
| if (seq_lens[bid] == 0) { | ||
| topk_ids[bid] = -1; | ||
| } | ||
| if (is_end) { | ||
| next_tokens[bid] = end_ids[0]; | ||
| stop_flags[bid] = true; | ||
| next_tokens[bid] = topk_ids[bid]; | ||
| } else { | ||
| if (stop_flags[bid]) { | ||
| if (seq_lens[bid] == 0) { | ||
| topk_ids[bid] = -1; | ||
| } else { | ||
| topk_ids[bid] = end_ids[0]; | ||
| next_tokens[bid] = end_ids[0]; | ||
| } | ||
| } else { | ||
| next_tokens[bid] = topk_ids[bid]; | ||
| } | ||
| } | ||
|
|
||
| const int64_t current_step = step_idx[bid]; | ||
| // check min_tokens | ||
| const int64_t min_token_limit = min_tokens[bid]; | ||
| const bool below_min_tokens = current_step < min_token_limit; | ||
|
|
||
| // If haven't reached min_tokens, cannot stop for any reason | ||
| if (below_min_tokens) { | ||
| if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) { | ||
| return; | ||
| } | ||
| } else { | ||
| if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) { | ||
| stop_flags[bid] = true; | ||
| } | ||
|
|
||
| if (!stop_flags[bid] && stop_token_ids != nullptr && | ||
| stop_token_ids_len != nullptr) { | ||
| const int num_stop_tokens = stop_token_ids_len[bid]; | ||
|
|
||
| if (num_stop_tokens > 0) { | ||
| const int64_t last_token_id = topk_ids[bid]; | ||
|
|
||
| if (last_token_id >= 0) { | ||
| const int64_t *stop_tokens_now = | ||
| stop_token_ids + bid * stop_token_ids_max_len; | ||
|
|
||
| for (int i = 0; i < num_stop_tokens; ++i) { | ||
| const int64_t stop_token = stop_tokens_now[i]; | ||
|
|
||
| if (stop_token >= 0 && last_token_id == stop_token) { | ||
| stop_flags[bid] = true; | ||
| next_tokens[bid] = end_ids[0]; | ||
| topk_ids[bid] = end_ids[0]; | ||
| break; | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid]; | ||
| if (stop_seq_len <= 0) return; | ||
|
|
||
| const int64_t current_step = step_idx[bid]; | ||
| const int64_t min_token_limit = min_tokens[bid]; | ||
| if (current_step < min_token_limit) return; | ||
|
|
||
| const int64_t *stop_seq_now = | ||
| stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len; | ||
| const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len; | ||
| const int64_t step_idx_now = step_idx[bid]; | ||
|
|
||
| bool is_end = true; | ||
| int count = 1; | ||
| for (int i = stop_seq_len - 1; i >= 0; --i) { | ||
| if ((step_idx_now - count) < 0 || | ||
| pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) { | ||
| is_end = false; | ||
| break; | ||
| } | ||
| } | ||
| if (is_end) { | ||
| next_tokens[bid] = end_ids[0]; | ||
| stop_flags[bid] = true; | ||
| topk_ids[bid] = end_ids[0]; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| void GetStopFlagsMulti(const paddle::Tensor &topk_ids, | ||
|
|
@@ -101,50 +148,79 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids, | |
| const paddle::Tensor &step_idx, | ||
| const paddle::Tensor &stop_seqs, | ||
| const paddle::Tensor &stop_seqs_len, | ||
| const paddle::Tensor &stop_token_ids, | ||
| const paddle::Tensor &stop_token_ids_len, | ||
| const paddle::Tensor &min_tokens, | ||
| const bool beam_search) { | ||
| PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64); | ||
| PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL); | ||
| bool prefill_one_step_stop = false; | ||
| if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) { | ||
| // std::cout << "Your PATH is: " << env_p << '\n'; | ||
| if (env_p[0] == '1') { | ||
| prefill_one_step_stop = true; | ||
| } | ||
| PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64); | ||
| PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL); | ||
| bool prefill_one_step_stop = false; | ||
| if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) { | ||
| // std::cout << "Your PATH is: " << env_p << '\n'; | ||
|
||
| if (env_p[0] == '1') { | ||
| prefill_one_step_stop = true; | ||
| } | ||
| } | ||
|
|
||
| #ifdef PADDLE_WITH_CUSTOM_DEVICE | ||
| auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place())); | ||
| auto cu_stream = dev_ctx->stream(); | ||
| auto dev_ctx = static_cast<const phi::CustomContext *>( | ||
| paddle::experimental::DeviceContextPool::Instance().Get( | ||
| topk_ids.place())); | ||
| auto cu_stream = dev_ctx->stream(); | ||
| #else | ||
| auto cu_stream = topk_ids.stream(); | ||
| auto cu_stream = topk_ids.stream(); | ||
| #endif | ||
| std::vector<int64_t> shape = topk_ids.shape(); | ||
| int64_t bs_now = shape[0]; | ||
| int64_t end_length = end_ids.shape()[0]; | ||
| int stop_seqs_bs = stop_seqs.shape()[1]; | ||
| int stop_seqs_max_len = stop_seqs.shape()[2]; | ||
| int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; | ||
| set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>( | ||
| const_cast<bool *>(stop_flags.data<bool>()), | ||
| const_cast<int64_t *>(topk_ids.data<int64_t>()), | ||
| const_cast<int64_t *>(next_tokens.data<int64_t>()), | ||
| end_ids.data<int64_t>(), | ||
| seq_lens.data<int>(), | ||
| bs_now, | ||
| end_length, | ||
| pre_ids.data<int64_t>(), | ||
| pre_ids.shape()[1], | ||
| step_idx.data<int64_t>(), | ||
| stop_seqs.data<int64_t>(), | ||
| stop_seqs_len.data<int>(), | ||
| stop_seqs_bs, | ||
| stop_seqs_max_len, | ||
| beam_search, | ||
| prefill_one_step_stop); | ||
| std::vector<int64_t> shape = topk_ids.shape(); | ||
| int64_t bs_now = shape[0]; | ||
| int64_t end_length = end_ids.shape()[0]; | ||
| int stop_seqs_bs = stop_seqs.shape()[1]; | ||
| int stop_seqs_max_len = stop_seqs.shape()[2]; | ||
|
|
||
| int stop_token_ids_max_len = 0; | ||
| const int64_t *stop_token_ids_ptr = nullptr; | ||
| const int *stop_token_ids_len_ptr = nullptr; | ||
| if (stop_token_ids.is_initialized() && stop_token_ids_len.is_initialized()) { | ||
| stop_token_ids_max_len = stop_token_ids.shape()[1]; // [bs,max_stop_tokens] | ||
| stop_token_ids_ptr = stop_token_ids.data<int64_t>(); | ||
| stop_token_ids_len_ptr = stop_token_ids_len.data<int>(); | ||
| } | ||
| int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; | ||
| set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>( | ||
| const_cast<bool *>(stop_flags.data<bool>()), | ||
| const_cast<int64_t *>(topk_ids.data<int64_t>()), | ||
| const_cast<int64_t *>(next_tokens.data<int64_t>()), | ||
| end_ids.data<int64_t>(), | ||
| seq_lens.data<int>(), | ||
| bs_now, | ||
| end_length, | ||
| pre_ids.data<int64_t>(), | ||
| pre_ids.shape()[1], | ||
| step_idx.data<int64_t>(), | ||
| stop_seqs.data<int64_t>(), | ||
| stop_seqs_len.data<int>(), | ||
| stop_seqs_bs, | ||
| stop_seqs_max_len, | ||
| stop_token_ids_ptr, | ||
| stop_token_ids_len_ptr, | ||
| stop_token_ids_max_len, | ||
| min_tokens.data<int64_t>(), | ||
| beam_search, | ||
| prefill_one_step_stop); | ||
| } | ||
|
|
||
| PD_BUILD_STATIC_OP(set_stop_value_multi_ends) | ||
| .Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len"}) | ||
| .Inputs({"topk_ids", | ||
| "stop_flags", | ||
| "seq_lens", | ||
| "end_ids", | ||
| "next_tokens", | ||
| "pre_ids", | ||
| "step_idx", | ||
| "stop_seqs", | ||
| "stop_seqs_len", | ||
| "stop_token_ids", | ||
| "stop_token_ids_len", | ||
| "min_tokens"}) | ||
| .Attrs({"beam_search: bool"}) | ||
| .Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"}) | ||
| .SetInplaceMap({{"topk_ids", "topk_ids_out"}, | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -37,7 +37,9 @@ | |||||
| # Maximum number of stop sequences. | ||||||
| "FD_MAX_STOP_SEQS_NUM": lambda: int(os.getenv("FD_MAX_STOP_SEQS_NUM", "5")), | ||||||
| # Maximum length of stop sequences. | ||||||
| "FD_STOP_SEQS_MAX_LEN": lambda: int(os.getenv("FD_STOP_SEQS_MAX_LEN", "8")), | ||||||
| "FD_STOP_SEQS_MAX_LEN": lambda: int(os.getenv("FD_STOP_SEQS_MAX_LEN", "8")), | ||||||
| # Maximum length of stop token ids. | ||||||
| "FD_STOP_TOKEN_IDS_MAX_LEN": lambda: int(os.getenv("FD_STOP_TOKEN_IDS_MAX_LEN", "8")), | ||||||
|
||||||
| "FD_STOP_TOKEN_IDS_MAX_LEN": lambda: os.getenv("FD_STOP_SEQS_MAX_LEN", "8"), | |
| "FD_STOP_TOKEN_IDS_MAX_LEN": lambda: int(os.getenv("FD_STOP_TOKEN_IDS_MAX_LEN", "8")), |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -216,7 +216,7 @@ def process_request(self, request, max_model_len=None, **kwargs): | |
| stop_sequences = request.get("stop", []) | ||
| if stop_sequences is not None and len(stop_sequences) != 0: | ||
| stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences) | ||
| request.set("stop_token_ids", stop_seqs) | ||
| request.set("stop_seqs", stop_seqs) | ||
|
||
| request.set("stop_seqs_len", stop_seqs_len) | ||
|
|
||
| # processing bad_words | ||
|
|
@@ -284,14 +284,15 @@ def process_request_dict(self, request, max_model_len=None, **kwargs): | |
| """ | ||
| data_processor_logger.info(f"Start processing request dict: {request}") | ||
| request = self._apply_default_parameters(request) | ||
|
|
||
| if not request.get("eos_token_ids"): | ||
| request["eos_token_ids"] = self.eos_token_ids | ||
|
|
||
| # processing stop_sequences | ||
| stop_sequences = request.get("stop", []) | ||
| if stop_sequences: | ||
| stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences) | ||
| request["stop_token_ids"] = stop_seqs | ||
| request["stop"] = stop_seqs | ||
| request["stop_seqs_len"] = stop_seqs_len | ||
|
|
||
| # processing bad_words | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是不是满足 below_min_tokens 都会返回？
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done