4 changes: 4 additions & 0 deletions custom_ops/gpu_ops/cpp_extensions.cc
@@ -352,6 +352,10 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::Tensor &stop_token_ids,
const paddle::Tensor &stop_token_ids_len,
const paddle::Tensor &min_tokens,
const paddle::Tensor &max_tokens,
const bool beam_search);


89 changes: 86 additions & 3 deletions custom_ops/gpu_ops/stop_generation_multi_ends.cu
@@ -37,11 +37,17 @@ __global__ void set_value_by_flags(bool *stop_flags,
const int *stop_seqs_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int64_t *stop_token_ids,
const int *stop_token_ids_len,
const int stop_token_ids_max_len,
const int64_t *min_tokens,
const int64_t *max_tokens,
bool beam_search,
bool prefill_one_step_stop) {
int tid = threadIdx.x;
int bid = blockIdx.x;
if (tid >= stop_seqs_bs) return;

if (bid < bs) {
if(tid == 0){
if (prefill_one_step_stop) {
@@ -62,13 +68,69 @@ __global__ void set_value_by_flags(bool *stop_flags,
next_tokens[bid] = topk_ids[bid];
}
}
if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {

// check max_tokens
const int64_t current_step = step_idx[bid];
const int64_t max_token_limit = max_tokens[bid];

if (current_step >= max_token_limit) {
stop_flags[bid] = true;
next_tokens[bid] = end_ids[0];
topk_ids[bid] = end_ids[0];
return;
}
Collaborator: Is this max_tokens check necessary? There should already be an existing limit on max_tokens.

Collaborator (Author): It covers the case where a configured stop token would only be reached beyond max_tokens; generation is then cut off at max_tokens, so I think the check is necessary.

(A reference sketch of the intended stop ordering appears after this file's diff.)

// check min_tokens
const int64_t min_token_limit = min_tokens[bid];
const bool below_min_tokens = current_step < min_token_limit;

// If haven't reached min_tokens, cannot stop for any reason
if (below_min_tokens) {
if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {
return;
Collaborator: When min_tokens has not been reached, shouldn't every case return directly, not only the end-ids case?

Collaborator (Author): done
}
} else {
if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {
stop_flags[bid] = true;
}

if(!stop_flags[bid] && stop_token_ids != nullptr && stop_token_ids_len != nullptr)
{
const int num_stop_tokens = stop_token_ids_len[bid];

if(num_stop_tokens > 0)
{
const int64_t last_token_id = topk_ids[bid];

if(last_token_id >= 0)
{
const int64_t *stop_tokens_now = stop_token_ids + bid * stop_token_ids_max_len;

for(int i = 0; i < num_stop_tokens; ++i)
{
const int64_t stop_token = stop_tokens_now[i];

if(stop_token >= 0 && last_token_id == stop_token)
{
stop_flags[bid] = true;
next_tokens[bid] = end_ids[0];
topk_ids[bid] = end_ids[0];
break;
}
}
}
}
}
}
}
// dealing stop_seqs

const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
if (stop_seq_len <= 0) return;

const int64_t current_step = step_idx[bid];
const int64_t min_token_limit = min_tokens[bid];
if (current_step < min_token_limit) return;

const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
const int64_t step_idx_now = step_idx[bid];
@@ -90,6 +152,8 @@ __global__ void set_value_by_flags(bool *stop_flags,
}
}



void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens,
@@ -99,6 +163,10 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::Tensor &stop_token_ids,
const paddle::Tensor &stop_token_ids_len,
const paddle::Tensor &min_tokens,
const paddle::Tensor &max_tokens,
const bool beam_search) {
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
@@ -121,6 +189,16 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
int64_t end_length = end_ids.shape()[0];
int stop_seqs_bs = stop_seqs.shape()[1];
int stop_seqs_max_len = stop_seqs.shape()[2];

int stop_token_ids_max_len = 0;
const int64_t *stop_token_ids_ptr = nullptr;
const int *stop_token_ids_len_ptr = nullptr;
if(stop_token_ids.is_initialized() && stop_token_ids_len.is_initialized())
{
stop_token_ids_max_len = stop_token_ids.shape()[1]; // [bs,max_stop_tokens]
stop_token_ids_ptr = stop_token_ids.data<int64_t>();
stop_token_ids_len_ptr = stop_token_ids_len.data<int>();
}
int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
Expand All @@ -137,12 +215,17 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
stop_seqs_len.data<int>(),
stop_seqs_bs,
stop_seqs_max_len,
stop_token_ids_ptr,
stop_token_ids_len_ptr,
stop_token_ids_max_len,
min_tokens.data<int64_t>(),
max_tokens.data<int64_t>(),
beam_search,
prefill_one_step_stop);
}

PD_BUILD_STATIC_OP(set_stop_value_multi_ends)
.Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len"})
.Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len","stop_token_ids","stop_token_ids_len","min_tokens","max_tokens"})
.Attrs({"beam_search: bool"})
.Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"})
.SetInplaceMap({{"topk_ids", "topk_ids_out"},
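Regarding the review thread above, here is a minimal Python sketch of the per-request stop ordering that the kernel implements; it is not part of the PR and the names are illustrative. The max_tokens cap is checked first, then the min_tokens gate suppresses every stop condition, and only then are the end ids and the per-request stop_token_ids consulted (stop_seqs matching runs separately in the kernel, with the same min_tokens gate):

# Editorial reference sketch -- not part of the PR; names are illustrative.
def should_stop(token_id, step, min_tokens, max_tokens, eos_ids, stop_token_ids):
    if step >= max_tokens:            # max_tokens wins even over a later stop token
        return True
    if step < min_tokens:             # below min_tokens, no stop condition may fire
        return False
    if token_id in eos_ids:           # normal end-id check (skipped under beam search)
        return True
    if token_id in stop_token_ids:    # per-request stop token ids added by this PR
        return True
    return False                      # stop_seqs matching is handled separately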
1 change: 1 addition & 0 deletions fastdeploy/config.py
@@ -301,6 +301,7 @@ def read_from_env(self):
"""
self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
self.stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
self.stop_token_ids_max_len = int(envs.FD_STOP_TOKEN_IDS_MAX_LEN)

def reset_config_value(key, value):
if not hasattr(self, key.lower()):
8 changes: 4 additions & 4 deletions fastdeploy/engine/sampling_params.py
@@ -58,7 +58,7 @@ class SamplingParams:
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
seed: Random seed to use for the generation.
stop: list of strings that stop the generation when they are generated.
stop_seqs: list of strings that stop the generation when they are generated.
Collaborator: Will the rename from stop to stop_seqs here change the externally exposed interface name?

The returned output will not contain the stop strings.
stop_token_ids: list of tokens that stop the generation when they are
generated. The returned output will contain the stop tokens unless
@@ -90,7 +90,7 @@ class SamplingParams:
top_k: int = 0
min_p: float = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stop_seqs: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[List[int]] = None
stop_seqs_len: Optional[int] = None
max_tokens: Optional[int] = None
@@ -127,7 +127,7 @@ def from_optional(
top_k,
min_p,
seed=None,
stop=None,
stop_seqs=None,
stop_token_ids=None,
max_tokens=None,
reasoning_max_tokens=None,
@@ -149,7 +149,7 @@
top_k=top_k if top_k is not None else 0,
min_p=min_p if min_p is not None else 0.0,
seed=seed,
stop=stop,
stop_seqs=stop_seqs,
stop_token_ids=stop_token_ids,
max_tokens=max_tokens if max_tokens is not None else 8192,
reasoning_max_tokens=reasoning_max_tokens,
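In connection with the interface question above, a hedged usage sketch of the renamed dataclass field. Only fields visible in this diff are used, the remaining fields are assumed to keep their defaults (as the diff suggests, e.g. top_k: int = 0), and the token ids are illustrative rather than values defined by the PR:

# Hedged usage sketch -- constructs SamplingParams with the renamed field.
from fastdeploy.engine.sampling_params import SamplingParams

params = SamplingParams(
    max_tokens=256,
    stop_seqs=["</s>", "Observation:"],   # formerly `stop`
    stop_token_ids=[2, 151643],           # illustrative token ids, not from the PR
)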
2 changes: 2 additions & 0 deletions fastdeploy/envs.py
@@ -38,6 +38,8 @@
"FD_MAX_STOP_SEQS_NUM": lambda: os.getenv("FD_MAX_STOP_SEQS_NUM", "5"),
# Maximum length of stop sequences.
"FD_STOP_SEQS_MAX_LEN": lambda: os.getenv("FD_STOP_SEQS_MAX_LEN", "8"),
# Maximum length of stop token ids.
"FD_STOP_TOKEN_IDS_MAX_LEN": lambda: os.getenv("FD_STOP_SEQS_MAX_LEN", "8"),
# GPU devices that will be used. This is a string that
# splited by comma, such as 0,1,2.
"CUDA_VISIBLE_DEVICES": lambda: os.getenv("CUDA_VISIBLE_DEVICES", None),
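A small hedged sketch of how these limits are raised at deployment time; the values are only examples, and config.read_from_env above picks them up at startup:

# Hedged example -- override the padding sizes before launching FastDeploy.
import os

os.environ["FD_MAX_STOP_SEQS_NUM"] = "5"        # max number of stop sequences per request
os.environ["FD_STOP_SEQS_MAX_LEN"] = "8"        # max tokens per stop sequence
os.environ["FD_STOP_TOKEN_IDS_MAX_LEN"] = "16"  # max stop token ids per request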
4 changes: 2 additions & 2 deletions fastdeploy/input/ernie4_5_processor.py
@@ -96,7 +96,7 @@ def process_request(self, request, max_model_len=None, **kwargs):
stop_sequences = request.get("stop", [])
if stop_sequences is not None and len(stop_sequences) != 0:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request.set("stop_token_ids", stop_seqs)
request.set("stop_seqs", stop_seqs)
request.set("stop_seqs_len", stop_seqs_len)

# processing bad_words
@@ -177,7 +177,7 @@ def process_request_dict(self, request, max_model_len=None):
stop_sequences = request.get("stop", [])
if stop_sequences:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request["stop_token_ids"] = stop_seqs
request["stop_seqs"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len

# processing bad_words
2 changes: 1 addition & 1 deletion fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py
@@ -210,7 +210,7 @@ def process_request_dict(self, request, max_model_len=None):
stop_sequences = request.get("stop", [])
if stop_sequences:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request["stop_token_ids"] = stop_seqs
request["stop_seqs"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len

bad_words = request.get("bad_words")
5 changes: 3 additions & 2 deletions fastdeploy/input/text_processor.py
@@ -216,7 +216,7 @@ def process_request(self, request, max_model_len=None, **kwargs):
stop_sequences = request.get("stop", [])
if stop_sequences is not None and len(stop_sequences) != 0:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request.set("stop_token_ids", stop_seqs)
request.set("stop_seqs", stop_seqs)
request.set("stop_seqs_len", stop_seqs_len)

# processing bad_words
@@ -284,14 +284,15 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
"""
data_processor_logger.info(f"Start processing request dict: {request}")
request = self._apply_default_parameters(request)

if not request.get("eos_token_ids"):
request["eos_token_ids"] = self.eos_token_ids

# processing stop_sequences
stop_sequences = request.get("stop", [])
if stop_sequences:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request["stop_token_ids"] = stop_seqs
request["stop_seqs"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len

# processing bad_words
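The processors above all route stop strings through update_stop_seq, whose body is not part of this diff. As a rough sketch of what such a helper plausibly returns (one row of token ids per stop string plus its length, padded to FD_STOP_SEQS_MAX_LEN); the tokenizer call and the padding value are assumptions, not the PR's implementation:

# Editorial sketch only -- update_stop_seq's real body is not shown in this diff.
def update_stop_seq_sketch(stop_strings, tokenizer, max_len=8, pad_id=-1):
    stop_seqs, stop_seqs_len = [], []
    for s in stop_strings:
        ids = tokenizer.encode(s, add_special_tokens=False)[:max_len]  # assumed HF-style tokenizer
        stop_seqs_len.append(len(ids))
        stop_seqs.append(ids + [pad_id] * (max_len - len(ids)))        # pad each row to max_len
    return stop_seqs, stop_seqs_len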
12 changes: 10 additions & 2 deletions fastdeploy/model_executor/pre_and_post_process.py
@@ -249,8 +249,12 @@ def post_process_normal(
model_output.next_tokens,
model_output.pre_ids,
model_output.step_idx,
model_output.stop_token_ids,
model_output.stop_seqs,
model_output.stop_seqs_len,
model_output.stop_token_ids,
model_output.stop_token_ids_len,
model_output.min_tokens,
model_output.max_tokens,
False,
) # multi ends
elif current_platform.is_maca():
Expand All @@ -262,8 +266,12 @@ def post_process_normal(
model_output.next_tokens,
model_output.pre_ids,
model_output.step_idx,
model_output.stop_token_ids,
model_output.stop_seqs,
model_output.stop_seqs_len,
model_output.stop_token_ids,
model_output.stop_token_ids_len,
model_output.min_tokens,
model_output.max_tokens,
False,
) # multi ends
else: