diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index cedabeebc..1a2abd160 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -506,7 +506,7 @@ impl DerefMut for Tokenizer { #[derive(thiserror::Error, Debug)] #[error("{0}")] -pub struct TruncationParamError(String); +pub struct TruncationParamError(pub String); /// A `Tokenizer` is capable of encoding/decoding any text. #[derive(Clone, Debug)] @@ -619,16 +619,6 @@ where /// /// Fails if `stride` is too high relative to `max_length` and `post_processor.added_tokens()` pub fn with_truncation(&mut self, trunc: Option) -> Result<&mut Self> { - if let Some(trunc_params) = &trunc { - let n_added_tokens = self.get_n_added_tokens(false); - let effective_max_length = trunc_params.max_length - n_added_tokens; - if effective_max_length < trunc_params.stride { - return Err(Box::new(TruncationParamError(format!( - "tokenizer stride set to {}, which is greater than or equal to its effective max length of {} (= {} original max length - {} added special tokens), ", - trunc_params.stride, effective_max_length, trunc_params.max_length, n_added_tokens - )))); - } - } self.truncation = trunc; Ok(self) } @@ -1216,7 +1206,7 @@ where if add_special_tokens && n_added_tokens > 0 { let params = TruncationParams { - max_length: trunc.max_length - n_added_tokens, + max_length: if n_added_tokens > trunc.max_length {0} else {trunc.max_length - n_added_tokens}, ..*trunc }; truncate_encodings(encoding, pair_encoding, ¶ms)? diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index e9b392d2e..c4e5347e3 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -1,4 +1,4 @@ -use crate::tokenizer::{Encoding, Result}; +use crate::tokenizer::{Encoding, Result, TruncationParamError}; use serde::{Deserialize, Serialize}; use std::cmp; use std::mem; @@ -96,6 +96,13 @@ pub fn truncate_encodings( return Ok((encoding, pair_encoding)); }; + if params.stride > params.max_length { + return Err(Box::new(TruncationParamError(format!( + "tokenizer stride set to {}, which is greater than or equal to its effective max length of {} (= original max length - added special tokens), ", + params.stride, params.max_length + )))); + } + match params.strategy { TruncationStrategy::LongestFirst => { if let Some(other_encoding) = pair_encoding.as_mut() {