@@ -68,23 +68,34 @@ impl CoreBPE {
     fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
         py.allow_threads(|| {
             match std::str::from_utf8(bytes) {
+                // Straightforward case
                 Ok(text) => self.encode_ordinary(text),
+                // Oops, don't actually have UTF-8. But we need to do the regex splitting in
+                // Unicode space, so we make our best guess at where we would have splits
                 Err(e) => {
                     let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
                     let (tokens, last_piece_token_len) = self.encode(text, &HashSet::new());
                     let (mut tokens, last_piece_token_len) =
                         self._increase_last_piece_token_len(tokens, last_piece_token_len);
+
+                    let mut unstable_bytes;
                     if !tokens.is_empty() && last_piece_token_len > 0 {
                         // Lop off the tokens from the last piece and run BPE on the remaining bytes
-                        // Somewhat niche, but this may not be correct if we'd have had a regex
-                        // split between the valid UTF-8 and the invalid bytes, which is why this
-                        // method is private
-                        let mut unstable_bytes = self
+                        // This likely matches what models see better, e.g. if you assume we're
+                        // dealing with truncated UTF-8 bytes.
+                        // Niche, but note this may not be correct if we'd have had a regex
+                        // split between the valid UTF-8 and the invalid bytes.
+                        unstable_bytes = self
                             .decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
                             .unwrap();
                         unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
 
                         tokens.truncate(tokens.len() - last_piece_token_len);
+                    } else {
+                        unstable_bytes = bytes[e.valid_up_to()..].to_vec();
+                    }
+
+                    if !unstable_bytes.is_empty() {
                         match self.encoder.get(&unstable_bytes) {
                             Some(token) => tokens.push(*token),
                             None => {
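
For context, here is a minimal standalone sketch (not part of this patch) of the splitting technique the hunk builds on: `std::str::from_utf8` reports, via `Utf8Error::valid_up_to()`, how many leading bytes form valid UTF-8, so the input can be cut into a prefix that is safe to treat as text and a tail of raw bytes that still needs byte-level handling. The helper name `split_at_valid_utf8` and the example input are made up for illustration.

// Sketch only: `split_at_valid_utf8` is a hypothetical helper, not part of
// tiktoken. `Utf8Error::valid_up_to()` is the length of the longest valid
// UTF-8 prefix, so the slice splits into a text prefix and a byte tail.
fn split_at_valid_utf8(bytes: &[u8]) -> (&str, &[u8]) {
    match std::str::from_utf8(bytes) {
        // Straightforward case: the whole input is valid UTF-8.
        Ok(text) => (text, &[]),
        // Otherwise cut at the last valid boundary. The prefix is known to
        // be valid, so the unchecked conversion is sound.
        Err(e) => {
            let (valid, rest) = bytes.split_at(e.valid_up_to());
            (unsafe { std::str::from_utf8_unchecked(valid) }, rest)
        }
    }
}

fn main() {
    // "h" followed by a truncated codepoint: 0xC3 starts a two-byte sequence.
    let bytes = b"h\xc3";
    let (prefix, rest) = split_at_valid_utf8(bytes);
    assert_eq!(prefix, "h");
    assert_eq!(rest, b"\xc3");
}

The new `else` branch in the hunk covers the case where this tail exists but the valid prefix yielded no unstable tokens (e.g. input that begins with invalid bytes), so the tail still flows into the `encoder` lookup instead of being silently dropped.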