Fix unsigned integer underflow issue with truncation #1859
base: main
Conversation
When the truncation max_len is shorter than the number of added tokens, there is an underflow issue even when the user didn't ask to add special tokens.

Signed-off-by: Max de Bayser <[email protected]>
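For context, the failing operation is plain `usize` arithmetic. A minimal standalone sketch (my own illustration, not code from the PR) of what the unchecked subtraction does:

```rust
fn main() {
    let max_length: usize = 1;     // user-requested truncation length
    let n_added_tokens: usize = 2; // e.g. a Roberta-style post-processor adds <s> and </s>

    // The shape of the failing expression: on usize this panics with
    // "attempt to subtract with overflow" in debug builds and wraps around
    // to a huge value in release builds, so it is left commented out here.
    // let effective_max_length = max_length - n_added_tokens;

    // checked_sub makes the underflow observable instead of fatal.
    assert_eq!(max_length.checked_sub(n_added_tokens), None);
}
```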
Thanks for the bug report + PR.
I'm not sure the fix you propose is the optimal solution. To understand the problem space better, are you still using added tokens at this point? Wouldn't it be even easier to simply remove the added tokens by removing the post_processor from the tokenizer if you are not using them?
```diff
@@ -1216,7 +1206,7 @@ where
 
         if add_special_tokens && n_added_tokens > 0 {
             let params = TruncationParams {
-                max_length: trunc.max_length - n_added_tokens,
+                max_length: if n_added_tokens > trunc.max_length {0} else {trunc.max_length - n_added_tokens},
```
```diff
-                max_length: if n_added_tokens > trunc.max_length {0} else {trunc.max_length - n_added_tokens},
+                max_length: trunc.max_length.checked_sub(n_added_tokens).unwrap_or(0)
```
NIT: I feel like this is more readable
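For reference, a standalone check (my own sketch, not part of the PR) that the suggested `checked_sub(..).unwrap_or(0)` matches the `if`/`else` version, with `saturating_sub` as a third equivalent spelling:

```rust
fn main() {
    // (max_length, n_added_tokens) pairs covering the normal case, the exact
    // boundary, and the underflow case from the bug report.
    for (max_length, n_added_tokens) in [(10usize, 2usize), (2, 2), (1, 2)] {
        let with_if = if n_added_tokens > max_length {
            0
        } else {
            max_length - n_added_tokens
        };
        let with_checked = max_length.checked_sub(n_added_tokens).unwrap_or(0);
        let with_saturating = max_length.saturating_sub(n_added_tokens);
        assert_eq!(with_if, with_checked);
        assert_eq!(with_checked, with_saturating);
    }
}
```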
```diff
@@ -506,7 +506,7 @@ impl DerefMut for Tokenizer {
 
 #[derive(thiserror::Error, Debug)]
 #[error("{0}")]
-pub struct TruncationParamError(String);
+pub struct TruncationParamError(pub String);
```
No pub here please. Create a constructor if needed.
```diff
@@ -619,16 +619,6 @@ where
     ///
     /// Fails if `stride` is too high relative to `max_length` and `post_processor.added_tokens()`
     pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> Result<&mut Self> {
-        if let Some(trunc_params) = &trunc {
```
I don't feel great about modifying this sanitization.
It seems to me that one should be aware of the added tokens, as they are part of the standard tokenization defined by the tokenizer's creator. So preventing chunking below that count is kind of important.
If someone REALLY wants super low chunking and wants to ignore the added tokens, that seems specific enough that simply setting the post_processor to None would be much simpler at that point.
So we can keep this footgun check alive, and power users can still modify the tokenizer's behavior.
```diff
@@ -619,16 +619,6 @@ where
     ///
     /// Fails if `stride` is too high relative to `max_length` and `post_processor.added_tokens()`
    pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> Result<&mut Self> {
-        if let Some(trunc_params) = &trunc {
-            let n_added_tokens = self.get_n_added_tokens(false);
-            let effective_max_length = trunc_params.max_length - n_added_tokens;
```
This should probably be `checked_sub` so we don't panic either.
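Sketched against the hunk above (context paraphrased from the diff, not the exact code in the PR), the non-panicking version of that check would read roughly:

```rust
// Inside with_truncation: compute the effective budget without risking a
// usize underflow panic; the surrounding validation is elided.
if let Some(trunc_params) = &trunc {
    let n_added_tokens = self.get_n_added_tokens(false);
    let effective_max_length = trunc_params
        .max_length
        .checked_sub(n_added_tokens)
        .unwrap_or(0);
    // ... existing stride / max_length validation continues here ...
}
```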
When the truncation max_len is shorter than the number of added tokens, there is an underflow issue even when the user didn't ask to add special tokens.
For example, this code here:
Fails with the following error:
This error happens with the ibm-granite model because it's a RobertaModel, which adds 2 special tokens. With Llama this issue does not happen. I found this problem in the context of this vllm issue: vllm-project/vllm#22635.
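A minimal sketch of the failure mode (my own illustration: the tokenizer file path is hypothetical and assumed to point at a Roberta-style tokenizer whose post-processor adds 2 special tokens):

```rust
use tokenizers::{Tokenizer, TruncationParams};

fn main() -> tokenizers::Result<()> {
    // Hypothetical tokenizer.json for a Roberta-style model whose
    // post-processor adds 2 special tokens (e.g. <s> ... </s>).
    let mut tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // max_length (1) is smaller than the 2 added tokens, so the sanity check
    // in with_truncation computes `1 - 2` on a usize and underflows, even
    // though the encode call below never asks for special tokens.
    tokenizer.with_truncation(Some(TruncationParams {
        max_length: 1,
        ..Default::default()
    }))?;

    let _encoding = tokenizer.encode("some text", false)?;
    Ok(())
}
```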
The idea of this PR is to move the verification containing the problematic subtraction from initialization (`with_truncation`) to the actual encode call, where it can take into account the value of `add_special_tokens`.
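In other words (an illustrative sketch of the approach, not the PR's exact code), the effective budget becomes a function of `add_special_tokens` evaluated at encode time:

```rust
/// Sketch of the encode-time computation this PR aims for; the name and
/// signature are illustrative, not taken from the PR. The added-token
/// budget only matters when special tokens are actually requested, and the
/// subtraction never underflows.
fn effective_max_length(max_length: usize, n_added_tokens: usize, add_special_tokens: bool) -> usize {
    if add_special_tokens {
        max_length.saturating_sub(n_added_tokens)
    } else {
        max_length
    }
}

fn main() {
    // Roberta-style tokenizer: 2 added tokens; user asks for max_length = 1.
    assert_eq!(effective_max_length(1, 2, false), 1); // no special tokens: full budget
    assert_eq!(effective_max_length(1, 2, true), 0);  // clamped instead of underflowing
}
```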